Merge branch 'Wei-Shaw:main' into main
This commit is contained in:
@@ -32,7 +32,7 @@ type AccountRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
@@ -75,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ type AdminService interface {
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
@@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
@@ -4119,6 +4119,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation5mTokens = int(v)
|
||||
}
|
||||
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation1hTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
@@ -4141,6 +4150,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation5mTokens = int(v)
|
||||
}
|
||||
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||
usage.CacheCreation1hTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ type ModelPricing struct {
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
}
|
||||
|
||||
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
// 启用 5m/1h 分类计费的条件:
|
||||
// 1. 存在 1h 价格
|
||||
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
SupportsCacheBreakdown: false,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
|
||||
@@ -87,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@@ -349,6 +349,8 @@ type ClaudeUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||
}
|
||||
|
||||
// ForwardResult 转发结果
|
||||
@@ -4403,6 +4405,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
|
||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||
@@ -4431,6 +4441,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4451,6 +4469,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
@@ -4560,10 +4586,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
@@ -4597,6 +4625,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
@@ -4741,10 +4771,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
@@ -4778,6 +4810,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
|
||||
@@ -74,7 +74,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
|
||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: opsAccountsPageSize,
|
||||
}, platformFilter, "", "", "")
|
||||
}, platformFilter, "", "", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -27,14 +27,15 @@ var (
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
@@ -45,14 +46,15 @@ type PricingRemoteClient interface {
|
||||
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
type LiteLLMRawEntry struct {
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
@@ -318,6 +320,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.CacheCreationInputTokenCost != nil {
|
||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||
}
|
||||
if entry.CacheCreationInputTokenCostAbove1hr != nil {
|
||||
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
|
||||
}
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
|
||||
@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
||||
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
|
||||
windowEnd := result.resetAt
|
||||
if result.fiveHourReset != nil {
|
||||
windowEnd = *result.fiveHourReset
|
||||
}
|
||||
windowStart := windowEnd.Add(-5 * time.Hour)
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
|
||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
// 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
if resetTimestamp == "" {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
|
||||
return nil
|
||||
}
|
||||
|
||||
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
|
||||
type anthropic429Result struct {
|
||||
resetAt time.Time // The correct reset time to use for SetRateLimited
|
||||
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
|
||||
}
|
||||
|
||||
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
|
||||
// to determine which window (5h or 7d) actually triggered the 429.
|
||||
//
|
||||
// Headers used:
|
||||
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
|
||||
// - anthropic-ratelimit-unified-5h-reset
|
||||
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
|
||||
// - anthropic-ratelimit-unified-7d-reset
|
||||
//
|
||||
// Returns nil when the per-window headers are absent (caller should fall back to
|
||||
// the aggregated anthropic-ratelimit-unified-reset header).
|
||||
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
|
||||
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
|
||||
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
|
||||
|
||||
if reset5hStr == "" && reset7dStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var reset5h, reset7d *time.Time
|
||||
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
reset5h = &t
|
||||
}
|
||||
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
reset7d = &t
|
||||
}
|
||||
|
||||
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
|
||||
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
|
||||
|
||||
slog.Info("anthropic_429_window_analysis",
|
||||
"is_5h_exceeded", is5hExceeded,
|
||||
"is_7d_exceeded", is7dExceeded,
|
||||
"reset_5h", reset5hStr,
|
||||
"reset_7d", reset7dStr,
|
||||
)
|
||||
|
||||
// Select the correct reset time based on which window(s) are exceeded.
|
||||
var chosen *time.Time
|
||||
switch {
|
||||
case is5hExceeded && is7dExceeded:
|
||||
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
|
||||
chosen = reset7d
|
||||
if chosen == nil {
|
||||
chosen = reset5h
|
||||
}
|
||||
case is5hExceeded:
|
||||
chosen = reset5h
|
||||
case is7dExceeded:
|
||||
chosen = reset7d
|
||||
default:
|
||||
// Neither flag clearly exceeded — pick the sooner reset as best guess
|
||||
chosen = pickSooner(reset5h, reset7d)
|
||||
}
|
||||
|
||||
if chosen == nil {
|
||||
return nil
|
||||
}
|
||||
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
|
||||
}
|
||||
|
||||
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
|
||||
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
|
||||
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
|
||||
prefix := "anthropic-ratelimit-unified-" + window + "-"
|
||||
|
||||
// Check surpassed-threshold first (most explicit signal)
|
||||
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fall back to utilization >= 1.0
|
||||
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
|
||||
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
|
||||
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pickSooner returns whichever of the two time pointers is earlier.
|
||||
// If only one is non-nil, it is returned. If both are nil, returns nil.
|
||||
func pickSooner(a, b *time.Time) *time.Time {
|
||||
switch {
|
||||
case a != nil && b != nil:
|
||||
if a.Before(*b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
case a != nil:
|
||||
return a
|
||||
default:
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
|
||||
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
|
||||
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
|
||||
// fiveHourReset should still be populated for session window calculation
|
||||
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
|
||||
result := calculateAnthropic429ResetTime(http.Header{})
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
|
||||
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1770998400)
|
||||
}
|
||||
|
||||
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
|
||||
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||
|
||||
result := calculateAnthropic429ResetTime(headers)
|
||||
assertAnthropicResult(t, result, 1771549200)
|
||||
|
||||
if result.fiveHourReset != nil {
|
||||
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAnthropicWindowExceeded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
window string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "utilization above 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
|
||||
window: "5h",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "utilization exactly 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
|
||||
window: "5h",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "utilization below 1.0",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
|
||||
window: "5h",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold true",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
|
||||
window: "7d",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold True (case insensitive)",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
|
||||
window: "7d",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "surpassed-threshold false",
|
||||
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
|
||||
window: "7d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no headers",
|
||||
headers: http.Header{},
|
||||
window: "5h",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := isAnthropicWindowExceeded(tc.headers, tc.window)
|
||||
if got != tc.expected {
|
||||
t.Errorf("expected %v, got %v", tc.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// assertAnthropicResult is a test helper that verifies the result is non-nil and
|
||||
// has the expected resetAt unix timestamp.
|
||||
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
|
||||
t.Helper()
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return // unreachable, but satisfies staticcheck SA5011
|
||||
}
|
||||
want := time.Unix(wantUnix, 0)
|
||||
if !result.resetAt.Equal(want) {
|
||||
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func makeHeader(key, value string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set(key, value)
|
||||
return h
|
||||
}
|
||||
@@ -26,8 +26,8 @@ type UsageLog struct {
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
|
||||
Reference in New Issue
Block a user