From acb718d35560a8d83029fa8825a5f6ae93f16b40 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Sat, 3 Jan 2026 06:32:51 -0800 Subject: [PATCH] =?UTF-8?q?perf(gateway):=20=E4=BC=98=E5=8C=96=E8=B4=9F?= =?UTF-8?q?=E8=BD=BD=E6=84=9F=E7=9F=A5=E8=B0=83=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要改进: - 优化负载感知调度的准确性和响应速度 - 将 AccountUsageService 的包级缓存改为依赖注入 - 修复 SSE/JSON 转义和 nil 安全问题 - 恢复 Google One 功能兼容性 --- .../internal/repository/concurrency_cache.go | 217 ++------------- .../concurrency_cache_integration_test.go | 39 ++- .../internal/service/account_usage_service.go | 247 +++++++++++++++--- .../internal/service/concurrency_service.go | 9 - .../service/gateway_multiplatform_test.go | 10 +- backend/internal/service/gateway_request.go | 46 +++- backend/internal/service/gateway_service.go | 111 +++++--- 7 files changed, 369 insertions(+), 310 deletions(-) diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index dfa555aa..0831f5eb 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -27,14 +27,8 @@ const ( accountSlotKeyPrefix = "concurrency:account:" // 格式: concurrency:user:{userID} userSlotKeyPrefix = "concurrency:user:" - - // Wait queue keys (global structures) - // - total: integer total queue depth across all users - // - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup) - // - counts: hash of userID -> current wait count - waitQueueTotalKey = "concurrency:wait:total" - waitQueueUpdatedKey = "concurrency:wait:updated" - waitQueueCountsKey = "concurrency:wait:counts" + // 等待队列计数器格式: concurrency:wait:{userID} + waitQueueKeyPrefix = "concurrency:wait:" // 账号级等待队列计数器格式: wait:account:{accountID} accountWaitKeyPrefix = "wait:account:" @@ -100,55 +94,27 @@ var ( `) // incrementWaitScript - only sets TTL on first creation to avoid refreshing - // KEYS[1] = total key - // 
KEYS[2] = updated zset key - // KEYS[3] = counts hash key - // ARGV[1] = userID - // ARGV[2] = maxWait - // ARGV[3] = TTL in seconds - // ARGV[4] = cleanup limit + // KEYS[1] = wait queue key + // ARGV[1] = maxWait + // ARGV[2] = TTL in seconds incrementWaitScript = redis.NewScript(` - local totalKey = KEYS[1] - local updatedKey = KEYS[2] - local countsKey = KEYS[3] - - local userID = ARGV[1] - local maxWait = tonumber(ARGV[2]) - local ttl = tonumber(ARGV[3]) - local cleanupLimit = tonumber(ARGV[4]) - - redis.call('SETNX', totalKey, 0) - - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - - -- Cleanup expired users (bounded) - local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit) - for _, uid in ipairs(expired) do - local c = tonumber(redis.call('HGET', countsKey, uid) or '0') - if c > 0 then - redis.call('DECRBY', totalKey, c) - end - redis.call('HDEL', countsKey, uid) - redis.call('ZREM', updatedKey, uid) + local current = redis.call('GET', KEYS[1]) + if current == false then + current = 0 + else + current = tonumber(current) end - local current = tonumber(redis.call('HGET', countsKey, userID) or '0') - if current >= maxWait then + if current >= tonumber(ARGV[1]) then return 0 end - local newVal = current + 1 - redis.call('HSET', countsKey, userID, newVal) - redis.call('ZADD', updatedKey, now, userID) - redis.call('INCR', totalKey) + local newVal = redis.call('INCR', KEYS[1]) - -- Keep global structures from living forever in totally idle deployments. 
- local ttlKeep = ttl * 2 - redis.call('EXPIRE', totalKey, ttlKeep) - redis.call('EXPIRE', updatedKey, ttlKeep) - redis.call('EXPIRE', countsKey, ttlKeep) + -- Only set TTL on first creation to avoid refreshing zombie data + if newVal == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[2]) + end return 1 `) @@ -178,111 +144,6 @@ var ( // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` - local totalKey = KEYS[1] - local updatedKey = KEYS[2] - local countsKey = KEYS[3] - - local userID = ARGV[1] - local ttl = tonumber(ARGV[2]) - local cleanupLimit = tonumber(ARGV[3]) - - redis.call('SETNX', totalKey, 0) - - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - - -- Cleanup expired users (bounded) - local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit) - for _, uid in ipairs(expired) do - local c = tonumber(redis.call('HGET', countsKey, uid) or '0') - if c > 0 then - redis.call('DECRBY', totalKey, c) - end - redis.call('HDEL', countsKey, uid) - redis.call('ZREM', updatedKey, uid) - end - - local current = tonumber(redis.call('HGET', countsKey, userID) or '0') - if current <= 0 then - return 1 - end - - local newVal = current - 1 - if newVal <= 0 then - redis.call('HDEL', countsKey, userID) - redis.call('ZREM', updatedKey, userID) - else - redis.call('HSET', countsKey, userID, newVal) - redis.call('ZADD', updatedKey, now, userID) - end - redis.call('DECR', totalKey) - - local ttlKeep = ttl * 2 - redis.call('EXPIRE', totalKey, ttlKeep) - redis.call('EXPIRE', updatedKey, ttlKeep) - redis.call('EXPIRE', countsKey, ttlKeep) - - return 1 - `) - - // getTotalWaitScript returns the global wait depth with TTL cleanup. 
- // KEYS[1] = total key - // KEYS[2] = updated zset key - // KEYS[3] = counts hash key - // ARGV[1] = TTL in seconds - // ARGV[2] = cleanup limit - getTotalWaitScript = redis.NewScript(` - local totalKey = KEYS[1] - local updatedKey = KEYS[2] - local countsKey = KEYS[3] - - local ttl = tonumber(ARGV[1]) - local cleanupLimit = tonumber(ARGV[2]) - - redis.call('SETNX', totalKey, 0) - - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - - -- Cleanup expired users (bounded) - local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit) - for _, uid in ipairs(expired) do - local c = tonumber(redis.call('HGET', countsKey, uid) or '0') - if c > 0 then - redis.call('DECRBY', totalKey, c) - end - redis.call('HDEL', countsKey, uid) - redis.call('ZREM', updatedKey, uid) - end - - -- If totalKey got lost but counts exist (e.g. Redis restart), recompute once. - local total = redis.call('GET', totalKey) - if total == false then - total = 0 - local vals = redis.call('HVALS', countsKey) - for _, v in ipairs(vals) do - total = total + tonumber(v) - end - redis.call('SET', totalKey, total) - end - - local ttlKeep = ttl * 2 - redis.call('EXPIRE', totalKey, ttlKeep) - redis.call('EXPIRE', updatedKey, ttlKeep) - redis.call('EXPIRE', countsKey, ttlKeep) - - local result = tonumber(redis.call('GET', totalKey) or '0') - if result < 0 then - result = 0 - redis.call('SET', totalKey, 0) - end - return result - `) - - // decrementAccountWaitScript - account-level wait queue decrement - decrementAccountWaitScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) if current ~= false and tonumber(current) > 0 then redis.call('DECR', KEYS[1]) @@ -383,9 +244,7 @@ func userSlotKey(userID int64) string { } func waitQueueKey(userID int64) string { - // Historical: per-user string keys were used. - // Now we use global structures keyed by userID string. 
- return strconv.FormatInt(userID, 10) + return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) } func accountWaitKey(accountID int64) string { @@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) // Wait queue operations func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { - userKey := waitQueueKey(userID) - result, err := incrementWaitScript.Run( - ctx, - c.rdb, - []string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey}, - userKey, - maxWait, - c.waitQueueTTLSeconds, - 200, // cleanup limit per call - ).Int() + key := waitQueueKey(userID) + result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() if err != nil { return false, err } @@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, } func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { - userKey := waitQueueKey(userID) - _, err := decrementWaitScript.Run( - ctx, - c.rdb, - []string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey}, - userKey, - c.waitQueueTTLSeconds, - 200, // cleanup limit per call - ).Result() + key := waitQueueKey(userID) + _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } -func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) { - if c.rdb == nil { - return 0, nil - } - total, err := getTotalWaitScript.Run( - ctx, - c.rdb, - []string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey}, - c.waitQueueTTLSeconds, - 500, // cleanup limit per query (rare) - ).Int64() - if err != nil { - return 0, err - } - return int(total), nil -} - // Account wait queue operations func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { @@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, 
accoun func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { key := accountWaitKey(accountID) - _, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result() + _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 56cd1d2e..5983c832 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { userID := int64(20) - userKey := waitQueueKey(userID) + waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2) require.NoError(s.T(), err, "IncrementWaitCount 1") @@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { require.NoError(s.T(), err, "IncrementWaitCount 3") require.False(s.T(), ok, "expected wait increment over max to fail") - ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result() - require.NoError(s.T(), err, "TTL wait total key") - s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2) + ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() + require.NoError(s.T(), err, "TTL waitKey") + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") - val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int() - require.NoError(s.T(), err, "HGET wait queue count") + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } require.Equal(s.T(), 1, val, "expected wait count 1") - - total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int() - require.NoError(s.T(), 
err, "GET wait queue total") - require.Equal(s.T(), 1, total, "expected total wait count 1") } func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { userID := int64(300) - userKey := waitQueueKey(userID) + waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) // Test decrement on non-existent key - should not error and should not create negative value require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key") - // Verify count remains zero / absent. - val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int() - require.True(s.T(), errors.Is(err, redis.Nil)) + // Verify no key was created or it's not negative + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty") // Set count to 1, then decrement twice @@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { // Decrement again on 0 - should not go negative require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero") - // Verify per-user count is absent and total is non-negative. 
- _, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result() - require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero") - - total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int() + // Verify count is 0, not negative + val, err = s.rdb.Get(s.ctx, waitKey).Int() if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err) + require.NoError(s.T(), err, "Get waitKey after double decrement") } - require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count") + require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") } func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 0fc5c45e..50ee94d1 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "log" + "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) @@ -17,11 +19,11 @@ type UsageLogRepository interface { Delete(ctx context.Context, id int64) error ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) - ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) - ListByAPIKeyAndTimeRange(ctx 
context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) @@ -32,10 +34,10 @@ type UsageLogRepository interface { GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) - GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) + GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) @@ -51,7 +53,7 @@ type UsageLogRepository interface { // Aggregated stats (optimized) GetUserStatsAggregated(ctx 
context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) - GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) @@ -69,13 +71,33 @@ type windowStatsCache struct { timestamp time.Time } -var ( - apiCacheMap = sync.Map{} // 缓存 API 响应 - windowStatsCacheMap = sync.Map{} // 缓存窗口统计 +// antigravityUsageCache 缓存 Antigravity 额度数据 +type antigravityUsageCache struct { + usageInfo *UsageInfo + timestamp time.Time +} + +const ( apiCacheTTL = 10 * time.Minute windowStatsCacheTTL = 1 * time.Minute ) +// UsageCache 封装账户使用量相关的缓存 +type UsageCache struct { + apiCache *sync.Map // accountID -> *apiUsageCache + windowStatsCache *sync.Map // accountID -> *windowStatsCache + antigravityCache *sync.Map // accountID -> *antigravityUsageCache +} + +// NewUsageCache 创建 UsageCache 实例 +func NewUsageCache() *UsageCache { + return &UsageCache{ + apiCache: &sync.Map{}, + antigravityCache: &sync.Map{}, + windowStatsCache: &sync.Map{}, + } +} + // WindowStats 窗口期统计 type WindowStats struct { Requests int64 `json:"requests"` @@ -91,6 +113,12 @@ type UsageProgress struct { WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) } +// AntigravityModelQuota Antigravity 单个模型的配额信息 +type AntigravityModelQuota struct { + Utilization int `json:"utilization"` // 使用率 0-100 + ResetTime string `json:"reset_time"` // 重置时间 ISO8601 +} + // UsageInfo 账号使用量信息 type UsageInfo struct { UpdatedAt *time.Time 
`json:"updated_at,omitempty"` // 更新时间 @@ -99,6 +127,9 @@ type UsageInfo struct { SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 + + // Antigravity 多模型配额 + AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -124,19 +155,51 @@ type ClaudeUsageFetcher interface { // AccountUsageService 账号使用量查询服务 type AccountUsageService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - usageFetcher ClaudeUsageFetcher - geminiQuotaService *GeminiQuotaService + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageFetcher ClaudeUsageFetcher + geminiQuotaService *GeminiQuotaService + antigravityQuotaFetcher QuotaFetcher + cache *UsageCache } // NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService { +func NewAccountUsageService( + accountRepo AccountRepository, + usageLogRepo UsageLogRepository, + usageFetcher ClaudeUsageFetcher, + geminiQuotaService *GeminiQuotaService, + antigravityQuotaFetcher *AntigravityQuotaFetcher, + cache *UsageCache, +) *AccountUsageService { + if cache == nil { + cache = &UsageCache{ + apiCache: &sync.Map{}, + antigravityCache: &sync.Map{}, + windowStatsCache: &sync.Map{}, + } + } + if cache.apiCache == nil { + cache.apiCache = &sync.Map{} + } + if cache.antigravityCache == nil { + cache.antigravityCache = &sync.Map{} + } + if cache.windowStatsCache == nil { + cache.windowStatsCache = &sync.Map{} + } + + var quotaFetcher QuotaFetcher + if antigravityQuotaFetcher != nil { + quotaFetcher = antigravityQuotaFetcher + } return &AccountUsageService{ - 
accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - usageFetcher: usageFetcher, - geminiQuotaService: geminiQuotaService, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageFetcher: usageFetcher, + geminiQuotaService: geminiQuotaService, + antigravityQuotaFetcher: quotaFetcher, + cache: cache, } } @@ -154,12 +217,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return s.getGeminiUsage(ctx, account) } + // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 + if account.Platform == PlatformAntigravity { + return s.getAntigravityUsage(ctx, account) + } + // 只有oauth类型账号可以通过API获取usage(有profile scope) if account.CanGetUsage() { var apiResp *ClaudeUsageResponse // 1. 检查 API 缓存(10 分钟) - if cached, ok := apiCacheMap.Load(accountID); ok { + if cached, ok := s.cache.apiCache.Load(accountID); ok { if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { apiResp = cache.response } @@ -172,7 +240,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, err } // 缓存 API 响应 - apiCacheMap.Store(accountID, &apiUsageCache{ + s.cache.apiCache.Store(accountID, &apiUsageCache{ response: apiResp, timestamp: time.Now(), }) @@ -224,12 +292,70 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou totals := geminiAggregateUsage(stats) resetAt := geminiDailyResetTime(now) - usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) - usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) + usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost) + usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost) return usage, nil } +// 
getAntigravityUsage 获取 Antigravity 账户额度 +func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + + // Ensure project_id is stable for quota queries. + if strings.TrimSpace(account.GetCredential("project_id")) == "" { + projectID := antigravity.GenerateMockProjectID() + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["project_id"] = projectID + if s.accountRepo != nil { + _, err := s.accountRepo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{ + Credentials: map[string]any{"project_id": projectID}, + }) + if err != nil { + log.Printf("Failed to persist antigravity project_id for account %d: %v", account.ID, err) + } + } + } + + // 1. 检查缓存(10 分钟) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { + // 重新计算 RemainingSeconds + usage := cloneUsageInfo(cache.usageInfo) + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = remainingSecondsUntil(*usage.FiveHour.ResetsAt) + } + return usage, nil + } + } + + // 2. 获取代理 URL + proxyURL, err := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + if err != nil { + log.Printf("Failed to get proxy URL for account %d: %v", account.ID, err) + proxyURL = "" + } + + // 3. 调用 API 获取额度 + result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) + if err != nil { + return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) + } + + // 4. 
缓存结果 + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: result.UsageInfo, + timestamp: time.Now(), + }) + + return result.UsageInfo, nil +} + // addWindowStats 为 usage 数据添加窗口期统计 // 使用独立缓存(1 分钟),与 API 缓存分离 func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { @@ -241,7 +367,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou // 检查窗口统计缓存(1 分钟) var windowStats *WindowStats - if cached, ok := windowStatsCacheMap.Load(account.ID); ok { + if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok { if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { windowStats = cache.stats } @@ -269,7 +395,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou } // 缓存窗口统计(1 分钟) - windowStatsCacheMap.Store(account.ID, &windowStatsCache{ + s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{ stats: windowStats, timestamp: time.Now(), }) @@ -342,12 +468,12 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空) info.FiveHour = &UsageProgress{ - Utilization: resp.FiveHour.Utilization, + Utilization: clampFloat64(resp.FiveHour.Utilization, 0, 100), } if resp.FiveHour.ResetsAt != "" { if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { info.FiveHour.ResetsAt = &fiveHourReset - info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds()) + info.FiveHour.RemainingSeconds = remainingSecondsUntil(fiveHourReset) } else { log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) } @@ -357,14 +483,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA if resp.SevenDay.ResetsAt != "" { if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil { info.SevenDay = &UsageProgress{ - Utilization: resp.SevenDay.Utilization, + 
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100), ResetsAt: &sevenDayReset, - RemainingSeconds: int(time.Until(sevenDayReset).Seconds()), + RemainingSeconds: remainingSecondsUntil(sevenDayReset), } } else { log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err) info.SevenDay = &UsageProgress{ - Utilization: resp.SevenDay.Utilization, + Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100), } } } @@ -373,14 +499,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA if resp.SevenDaySonnet.ResetsAt != "" { if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil { info.SevenDaySonnet = &UsageProgress{ - Utilization: resp.SevenDaySonnet.Utilization, + Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100), ResetsAt: &sonnetReset, - RemainingSeconds: int(time.Until(sonnetReset).Seconds()), + RemainingSeconds: remainingSecondsUntil(sonnetReset), } } else { log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err) info.SevenDaySonnet = &UsageProgress{ - Utilization: resp.SevenDaySonnet.Utilization, + Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100), } } } @@ -394,10 +520,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn // 如果有session_window信息 if account.SessionWindowEnd != nil { - remaining := int(time.Until(*account.SessionWindowEnd).Seconds()) - if remaining < 0 { - remaining = 0 - } + remaining := remainingSecondsUntil(*account.SessionWindowEnd) // 根据状态估算使用率 (百分比形式,100 = 100%) var utilization float64 @@ -409,6 +532,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn default: utilization = 0.0 } + utilization = clampFloat64(utilization, 0, 100) info.FiveHour = &UsageProgress{ Utilization: utilization, @@ -427,15 +551,12 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn return info } -func 
buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { +func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64) *UsageProgress { if limit <= 0 { return nil } - utilization := (float64(used) / float64(limit)) * 100 - remainingSeconds := int(resetAt.Sub(now).Seconds()) - if remainingSeconds < 0 { - remainingSeconds = 0 - } + utilization := clampFloat64((float64(used)/float64(limit))*100, 0, 100) + remainingSeconds := remainingSecondsUntil(resetAt) resetCopy := resetAt return &UsageProgress{ Utilization: utilization, @@ -448,3 +569,47 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64 }, } } + +func cloneUsageInfo(src *UsageInfo) *UsageInfo { + if src == nil { + return nil + } + dst := *src + if src.UpdatedAt != nil { + t := *src.UpdatedAt + dst.UpdatedAt = &t + } + dst.FiveHour = cloneUsageProgress(src.FiveHour) + dst.SevenDay = cloneUsageProgress(src.SevenDay) + dst.SevenDaySonnet = cloneUsageProgress(src.SevenDaySonnet) + dst.GeminiProDaily = cloneUsageProgress(src.GeminiProDaily) + dst.GeminiFlashDaily = cloneUsageProgress(src.GeminiFlashDaily) + if src.AntigravityQuota != nil { + dst.AntigravityQuota = make(map[string]*AntigravityModelQuota, len(src.AntigravityQuota)) + for k, v := range src.AntigravityQuota { + if v == nil { + dst.AntigravityQuota[k] = nil + continue + } + copyVal := *v + dst.AntigravityQuota[k] = &copyVal + } + } + return &dst +} + +func cloneUsageProgress(src *UsageProgress) *UsageProgress { + if src == nil { + return nil + } + dst := *src + if src.ResetsAt != nil { + t := *src.ResetsAt + dst.ResetsAt = &t + } + if src.WindowStats != nil { + statsCopy := *src.WindowStats + dst.WindowStats = &statsCopy + } + return &dst +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 8b0ad94c..65ef16db 100644 --- 
a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -32,7 +32,6 @@ type ConcurrencyCache interface { // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error - GetTotalWaitCount(ctx context.Context) (int, error) // 批量负载查询(只读) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) @@ -201,14 +200,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 } } -// GetTotalWaitCount returns the total wait queue depth across users. -func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) { - if s.cache == nil { - return 0, nil - } - return s.cache.GetTotalWaitCount(ctx) -} - // IncrementAccountWaitCount increments the wait queue counter for an account. func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { if s.cache == nil { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 806d2aef..0c8989fe 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6 func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { return nil } +func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return nil +} +func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error { return nil } @@ -276,7 +282,7 @@ func 
TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference( repo := &mockAccountRepoForPlatform{ accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, }, accountsByID: map[int64]*Account{}, @@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, }, accountsByID: map[int64]*Account{}, diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 32e9ffba..4cf40199 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -75,9 +75,19 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures +// +// Strategy: +// - When thinking.type != "enabled": Remove all thinking blocks +// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures +// (blocks with missing/empty/dummy signatures that would cause 400 errors) func FilterThinkingBlocks(body []byte) []byte { // Fast path: if body doesn't contain "thinking", skip parsing - if 
!bytes.Contains(body, []byte("thinking")) { + if !bytes.Contains(body, []byte(`"type":"thinking"`)) && + !bytes.Contains(body, []byte(`"type": "thinking"`)) && + !bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) && + !bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) && + !bytes.Contains(body, []byte(`"thinking":`)) && + !bytes.Contains(body, []byte(`"thinking" :`)) { return body } @@ -86,6 +96,14 @@ func FilterThinkingBlocks(body []byte) []byte { return body // Return original on parse error } + // Check if thinking is enabled + thinkingEnabled := false + if thinking, ok := req["thinking"].(map[string]any); ok { + if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + thinkingEnabled = true + } + } + messages, ok := req["messages"].([]any) if !ok { return body // No messages array @@ -98,6 +116,7 @@ func FilterThinkingBlocks(body []byte) []byte { continue } + role, _ := msgMap["role"].(string) content, ok := msgMap["content"].([]any) if !ok { continue @@ -106,6 +125,7 @@ func FilterThinkingBlocks(body []byte) []byte { // Filter thinking blocks from content array newContent := make([]any, 0, len(content)) filteredThisMessage := false + for _, block := range content { blockMap, ok := block.(map[string]any) if !ok { @@ -114,22 +134,34 @@ func FilterThinkingBlocks(body []byte) []byte { } blockType, _ := blockMap["type"].(string) - // Explicit Anthropic-style thinking block: {"type":"thinking", ...} - if blockType == "thinking" { + + // Handle thinking/redacted_thinking blocks + if blockType == "thinking" || blockType == "redacted_thinking" { + // When thinking is enabled and this is an assistant message, + // only keep thinking blocks with valid (non-empty, non-dummy) signatures + if thinkingEnabled && role == "assistant" { + signature, _ := blockMap["signature"].(string) + // Keep blocks with valid signatures, remove those without + if signature != "" && signature != "skip_thought_signature_validator" { + newContent = 
append(newContent, block) + continue + } + } + filtered = true filteredThisMessage = true - continue // Skip thinking blocks + continue } // Some clients send the "thinking" object without a "type" discriminator. - // Vertex/Claude still expects a signature for any thinking block, so we drop it. // We intentionally do not drop other typed blocks (e.g. tool_use) that might // legitimately contain a "thinking" key inside their payload. if blockType == "" { - if _, hasThinking := blockMap["thinking"]; hasThinking { + if thinkingContent, hasThinking := blockMap["thinking"]; hasThinking { + _ = thinkingContent filtered = true filteredThisMessage = true - continue // Skip thinking blocks + continue } } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 456cf81d..6d652fa6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( case AccountTypeOAuth, AccountTypeSetupToken: // Both oauth and setup-token use OAuth token flow return s.getOAuthToken(ctx, account) - case AccountTypeAPIKey: + case AccountTypeApiKey: apiKey := account.GetCredential("api_key") if apiKey == "" { return "", "", errors.New("api_key not found in credentials") @@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 应用模型映射(仅对apikey类型账号) originalModel := reqModel - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeApiKey { mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { // 替换请求体中的模型名 @@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, fmt.Errorf("upstream request failed: %w", err) } - // 检查是否需要重试 - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + // 优先检测thinking block签名错误(400)并重试一次 + if resp.StatusCode 
== 400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr == nil { + _ = resp.Body.Close() + + if s.isThinkingBlockSignatureError(respBody) { + log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) + + // 过滤thinking blocks并重试 + filteredBody := FilterThinkingBlocks(body) + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + if retryErr == nil { + // 使用重试后的响应,继续后续处理 + resp = retryResp + break + } + } + // 重试失败,恢复原始响应体继续处理 + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + break + } + // 不是thinking签名错误,恢复响应体 + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + } + } + + // 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if attempt < maxRetries { log.Printf("Account %d: upstream error %d, retry %d/%d after %v", account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) @@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理错误响应(不可重试的错误) - if resp.StatusCode >= 400 { - // 可选:对部分 400 触发 failover(默认关闭以保持语义) - if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - // ReadAll failed, fall back to normal error handling without consuming the stream - return s.handleErrorResponse(ctx, resp, c, account) + if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, 
account) } _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) @@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeApiKey { baseURL := account.GetBaseURL() targetURL = baseURL + "/v1/messages" } @@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures). - // We apply this for the main /v1/messages path as well as count_tokens. - body = FilterThinkingBlocks(body) - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { + if beta := defaultApiKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", beta) } } @@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool { return false } -func defaultAPIKeyBetaHeader(body []byte) string { +func defaultApiKeyBetaHeader(body []byte) string { modelID := gjson.GetBytes(body, "model").String() if strings.Contains(strings.ToLower(modelID), 
"haiku") { - return claude.APIKeyHaikuBetaHeader + return claude.ApiKeyHaikuBetaHeader } - return claude.APIKeyBetaHeader + return claude.ApiKeyBetaHeader } func truncateForLog(b []byte, maxBytes int) string { @@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string { return s } +// isThinkingBlockSignatureError 检测是否是thinking block签名错误 +// 这类错误可以通过过滤thinking blocks并重试来解决 +func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 检测thinking block签名相关的错误 + // 例如: "Invalid `signature` in `thinking` block" + return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) && + strings.Contains(msg, "signature") +} + func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 @@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res body, _ := io.ReadAll(resp.Body) // 处理上游错误,标记账号状态 - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + if shouldDisable { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) var errType, errMsg string @@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult - APIKey *APIKey + ApiKey *ApiKey User *User Account *Account Subscription *UserSubscription // 可选:订阅信息 @@ -1639,7 +1684,7 @@ type RecordUsageInput struct { // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result - apiKey := 
input.APIKey + apiKey := input.ApiKey user := input.User account := input.Account subscription := input.Subscription @@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu durationMs := int(result.Duration.Milliseconds()) usageLog := &UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, @@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body := parsed.Body reqModel := parsed.Model - // Antigravity 账户不支持 count_tokens 转发,返回估算值 - // 参考 Antigravity-Manager 和 proxycast 实现 + // Antigravity 账户不支持 count_tokens 转发,直接返回空值 if account.Platform == PlatformAntigravity { - c.JSON(http.StatusOK, gin.H{"input_tokens": 100}) + c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) return nil } // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeApiKey { if reqModel != "" { mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { @@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeApiKey { baseURL := account.GetBaseURL() targetURL = baseURL + "/v1/messages/count_tokens" } @@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // Filter thinking blocks from request body (prevents 400 errors from invalid signatures) - body = FilterThinkingBlocks(body) - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -1910,10 +1951,10 @@ func (s *GatewayService) 
buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { // API-key:与 messages 同步的按需 beta 注入(默认关闭) if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { + if beta := defaultApiKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", beta) } }