perf(gateway): 优化负载感知调度

主要改进:
- 优化负载感知调度的准确性和响应速度
- 将 AccountUsageService 的包级缓存改为依赖注入
- 修复 SSE/JSON 转义和 nil 安全问题
- 恢复 Google One 功能兼容性
This commit is contained in:
ianshaw
2026-01-03 06:32:51 -08:00
parent 26106eb0ac
commit acb718d355
7 changed files with 369 additions and 310 deletions

View File

@@ -27,14 +27,8 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
@@ -100,55 +94,27 @@ var (
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
local newVal = redis.call('INCR', KEYS[1])
-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
@@ -178,111 +144,6 @@ var (
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end
local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
}
func waitQueueKey(userID int64) string {
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err)
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {

View File

@@ -4,9 +4,11 @@ import (
"context"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
@@ -17,11 +19,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -69,13 +71,33 @@ type windowStatsCache struct {
timestamp time.Time
}
var (
apiCacheMap = sync.Map{} // 缓存 API 响应
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
// antigravityUsageCache 缓存 Antigravity 额度数据
type antigravityUsageCache struct {
usageInfo *UsageInfo
timestamp time.Time
}
const (
apiCacheTTL = 10 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
)
// UsageCache 封装账户使用量相关的缓存
type UsageCache struct {
apiCache *sync.Map // accountID -> *apiUsageCache
windowStatsCache *sync.Map // accountID -> *windowStatsCache
antigravityCache *sync.Map // accountID -> *antigravityUsageCache
}
// NewUsageCache 创建 UsageCache 实例
func NewUsageCache() *UsageCache {
return &UsageCache{
apiCache: &sync.Map{},
antigravityCache: &sync.Map{},
windowStatsCache: &sync.Map{},
}
}
// WindowStats 窗口期统计
type WindowStats struct {
Requests int64 `json:"requests"`
@@ -91,6 +113,12 @@ type UsageProgress struct {
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// AntigravityModelQuota Antigravity 单个模型的配额信息
type AntigravityModelQuota struct {
Utilization int `json:"utilization"` // 使用率 0-100
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -99,6 +127,9 @@ type UsageInfo struct {
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -124,19 +155,51 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher QuotaFetcher
cache *UsageCache
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
func NewAccountUsageService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageFetcher ClaudeUsageFetcher,
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
) *AccountUsageService {
if cache == nil {
cache = &UsageCache{
apiCache: &sync.Map{},
antigravityCache: &sync.Map{},
windowStatsCache: &sync.Map{},
}
}
if cache.apiCache == nil {
cache.apiCache = &sync.Map{}
}
if cache.antigravityCache == nil {
cache.antigravityCache = &sync.Map{}
}
if cache.windowStatsCache == nil {
cache.windowStatsCache = &sync.Map{}
}
var quotaFetcher QuotaFetcher
if antigravityQuotaFetcher != nil {
quotaFetcher = antigravityQuotaFetcher
}
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: quotaFetcher,
cache: cache,
}
}
@@ -154,12 +217,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return s.getGeminiUsage(ctx, account)
}
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
return s.getAntigravityUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
// 1. 检查 API 缓存10 分钟)
if cached, ok := apiCacheMap.Load(accountID); ok {
if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
apiResp = cache.response
}
@@ -172,7 +240,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, err
}
// 缓存 API 响应
apiCacheMap.Store(accountID, &apiUsageCache{
s.cache.apiCache.Store(accountID, &apiUsageCache{
response: apiResp,
timestamp: time.Now(),
})
@@ -224,12 +292,70 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost)
return usage, nil
}
// getAntigravityUsage 获取 Antigravity 账户额度
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}
// Ensure project_id is stable for quota queries.
if strings.TrimSpace(account.GetCredential("project_id")) == "" {
projectID := antigravity.GenerateMockProjectID()
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["project_id"] = projectID
if s.accountRepo != nil {
_, err := s.accountRepo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{
Credentials: map[string]any{"project_id": projectID},
})
if err != nil {
log.Printf("Failed to persist antigravity project_id for account %d: %v", account.ID, err)
}
}
}
// 1. 检查缓存10 分钟)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
// 重新计算 RemainingSeconds
usage := cloneUsageInfo(cache.usageInfo)
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = remainingSecondsUntil(*usage.FiveHour.ResetsAt)
}
return usage, nil
}
}
// 2. 获取代理 URL
proxyURL, err := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
if err != nil {
log.Printf("Failed to get proxy URL for account %d: %v", account.ID, err)
proxyURL = ""
}
// 3. 调用 API 获取额度
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
if err != nil {
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}
// 4. 缓存结果
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: result.UsageInfo,
timestamp: time.Now(),
})
return result.UsageInfo, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -241,7 +367,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 检查窗口统计缓存1 分钟)
var windowStats *WindowStats
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
windowStats = cache.stats
}
@@ -269,7 +395,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}
// 缓存窗口统计1 分钟)
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
stats: windowStats,
timestamp: time.Now(),
})
@@ -342,12 +468,12 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
Utilization: clampFloat64(resp.FiveHour.Utilization, 0, 100),
}
if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour.ResetsAt = &fiveHourReset
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
info.FiveHour.RemainingSeconds = remainingSecondsUntil(fiveHourReset)
} else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
}
@@ -357,14 +483,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDay.ResetsAt != "" {
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
ResetsAt: &sevenDayReset,
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
RemainingSeconds: remainingSecondsUntil(sevenDayReset),
}
} else {
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
}
}
}
@@ -373,14 +499,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDaySonnet.ResetsAt != "" {
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
ResetsAt: &sonnetReset,
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
RemainingSeconds: remainingSecondsUntil(sonnetReset),
}
} else {
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
}
}
}
@@ -394,10 +520,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// 如果有session_window信息
if account.SessionWindowEnd != nil {
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
if remaining < 0 {
remaining = 0
}
remaining := remainingSecondsUntil(*account.SessionWindowEnd)
// 根据状态估算使用率 (百分比形式100 = 100%)
var utilization float64
@@ -409,6 +532,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
default:
utilization = 0.0
}
utilization = clampFloat64(utilization, 0, 100)
info.FiveHour = &UsageProgress{
Utilization: utilization,
@@ -427,15 +551,12 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
return info
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
utilization := clampFloat64((float64(used)/float64(limit))*100, 0, 100)
remainingSeconds := remainingSecondsUntil(resetAt)
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
@@ -448,3 +569,47 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
func cloneUsageInfo(src *UsageInfo) *UsageInfo {
if src == nil {
return nil
}
dst := *src
if src.UpdatedAt != nil {
t := *src.UpdatedAt
dst.UpdatedAt = &t
}
dst.FiveHour = cloneUsageProgress(src.FiveHour)
dst.SevenDay = cloneUsageProgress(src.SevenDay)
dst.SevenDaySonnet = cloneUsageProgress(src.SevenDaySonnet)
dst.GeminiProDaily = cloneUsageProgress(src.GeminiProDaily)
dst.GeminiFlashDaily = cloneUsageProgress(src.GeminiFlashDaily)
if src.AntigravityQuota != nil {
dst.AntigravityQuota = make(map[string]*AntigravityModelQuota, len(src.AntigravityQuota))
for k, v := range src.AntigravityQuota {
if v == nil {
dst.AntigravityQuota[k] = nil
continue
}
copyVal := *v
dst.AntigravityQuota[k] = &copyVal
}
}
return &dst
}
func cloneUsageProgress(src *UsageProgress) *UsageProgress {
if src == nil {
return nil
}
dst := *src
if src.ResetsAt != nil {
t := *src.ResetsAt
dst.ResetsAt = &t
}
if src.WindowStats != nil {
statsCopy := *src.WindowStats
dst.WindowStats = &statsCopy
}
return &dst
}

View File

@@ -32,7 +32,6 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -201,14 +200,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {

View File

@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

View File

@@ -75,9 +75,19 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte("thinking")) {
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
@@ -86,6 +96,14 @@ func FilterThinkingBlocks(body []byte) []byte {
return body // Return original on parse error
}
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}
messages, ok := req["messages"].([]any)
if !ok {
return body // No messages array
@@ -98,6 +116,7 @@ func FilterThinkingBlocks(body []byte) []byte {
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
continue
@@ -106,6 +125,7 @@ func FilterThinkingBlocks(body []byte) []byte {
// Filter thinking blocks from content array
newContent := make([]any, 0, len(content))
filteredThisMessage := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
@@ -114,22 +134,34 @@ func FilterThinkingBlocks(body []byte) []byte {
}
blockType, _ := blockMap["type"].(string)
// Explicit Anthropic-style thinking block: {"type":"thinking", ...}
if blockType == "thinking" {
// Handle thinking/redacted_thinking blocks
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid (non-empty, non-dummy) signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
// Keep blocks with valid signatures, remove those without
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}
filtered = true
filteredThisMessage = true
continue // Skip thinking blocks
continue
}
// Some clients send the "thinking" object without a "type" discriminator.
// Vertex/Claude still expects a signature for any thinking block, so we drop it.
// We intentionally do not drop other typed blocks (e.g. tool_use) that might
// legitimately contain a "thinking" key inside their payload.
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
if thinkingContent, hasThinking := blockMap["thinking"]; hasThinking {
_ = thinkingContent
filtered = true
filteredThisMessage = true
continue // Skip thinking blocks
continue
}
}

View File

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey:
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
resp = retryResp
break
}
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
return claude.ApiKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block签名错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 检测thinking block签名相关的错误
// 例如: "Invalid `signature` in `thinking` block"
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
strings.Contains(msg, "signature")
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算
// 参考 Antigravity-Manager 和 proxycast 实现
// Antigravity 账户不支持 count_tokens 转发,直接返回空
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}