perf(gateway): 优化负载感知调度

主要改进:
- 优化负载感知调度的准确性和响应速度
- 将 AccountUsageService 的包级缓存改为依赖注入
- 修复 SSE/JSON 转义和 nil 安全问题
- 恢复 Google One 功能兼容性
This commit is contained in:
ianshaw
2026-01-03 06:32:51 -08:00
parent 26106eb0ac
commit acb718d355
7 changed files with 369 additions and 310 deletions

View File

@@ -27,14 +27,8 @@ const (
accountSlotKeyPrefix = "concurrency:account:" accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID} // 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:" userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
// Wait queue keys (global structures) waitQueueKeyPrefix = "concurrency:wait:"
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 账号级等待队列计数器格式: wait:account:{accountID} // 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:" accountWaitKeyPrefix = "wait:account:"
@@ -100,55 +94,27 @@ var (
`) `)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing // incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = total key // KEYS[1] = wait queue key
// KEYS[2] = updated zset key // ARGV[1] = maxWait
// KEYS[3] = counts hash key // ARGV[2] = TTL in seconds
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
incrementWaitScript = redis.NewScript(` incrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1] local current = redis.call('GET', KEYS[1])
local updatedKey = KEYS[2] if current == false then
local countsKey = KEYS[3] current = 0
else
local userID = ARGV[1] current = tonumber(current)
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0') if current >= tonumber(ARGV[1]) then
if current >= maxWait then
return 0 return 0
end end
local newVal = current + 1 local newVal = redis.call('INCR', KEYS[1])
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
-- Keep global structures from living forever in totally idle deployments. -- Only set TTL on first creation to avoid refreshing zombie data
local ttlKeep = ttl * 2 if newVal == 1 then
redis.call('EXPIRE', totalKey, ttlKeep) redis.call('EXPIRE', KEYS[1], ARGV[2])
redis.call('EXPIRE', updatedKey, ttlKeep) end
redis.call('EXPIRE', countsKey, ttlKeep)
return 1 return 1
`) `)
@@ -178,111 +144,6 @@ var (
// decrementWaitScript - same as before // decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(` decrementWaitScript = redis.NewScript(`
-- Decrements one user's wait count in the shared wait-queue structures.
-- KEYS[1] = total key (integer total across users)
-- KEYS[2] = updated zset key (userID scored by last update time in seconds)
-- KEYS[3] = counts hash key (userID -> current wait count)
-- ARGV[1] = userID
-- ARGV[2] = TTL in seconds
-- ARGV[3] = cleanup limit (max expired users removed per call)
-- Always returns 1.
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])
-- Make sure the total key exists so the DECR/DECRBY calls below start from 0.
redis.call('SETNX', totalKey, 0)
-- Use the Redis server clock (seconds) as the reference time.
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
-- HGET yields false for a missing field, hence the '0' fallback.
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
-- Nothing to decrement: return early. Note this path skips the TTL refresh below.
if current <= 0 then
return 1
end
local newVal = current - 1
if newVal <= 0 then
-- Count reached zero: drop the user from both the hash and the zset.
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
-- Store the new count and record this update time for the user.
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)
-- Refresh all three keys to twice the per-user TTL.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
//
// Steps, in order:
//  1. If the total key is missing, rebuild it from the counts hash.
//  2. Remove up to "cleanup limit" users whose last update is older than the
//     TTL, subtracting their counts from the total.
//  3. Clamp a negative total to 0.
//  4. Refresh the TTL of all three keys to 2*TTL and return the total.
//
// The rebuild replaces an unconditional SETNX that used to run first: with the
// SETNX in place the key always existed, so the rebuild branch could never run.
// The clamp now happens before the EXPIRE calls because a plain SET discards
// any TTL already set on the key.
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])

-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
-- With an empty counts hash this stores 0, which is what SETNX used to do.
if redis.call('EXISTS', totalKey) == 0 then
	local rebuilt = 0
	for _, v in ipairs(redis.call('HVALS', countsKey)) do
		rebuilt = rebuilt + (tonumber(v) or 0)
	end
	redis.call('SET', totalKey, rebuilt)
end

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
	local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
	if c > 0 then
		redis.call('DECRBY', totalKey, c)
	end
	redis.call('HDEL', countsKey, uid)
	redis.call('ZREM', updatedKey, uid)
end

-- Clamp first, then refresh TTLs, so the SET below cannot strip the TTL.
local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
	result = 0
	redis.call('SET', totalKey, 0)
end

local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)

return result
`)
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1]) redis.call('DECR', KEYS[1])
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
} }
func waitQueueKey(userID int64) string { func waitQueueKey(userID int64) string {
// Historical: per-user string keys were used. return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
} }
func accountWaitKey(accountID int64) string { func accountWaitKey(accountID int64) string {
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations // Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
userKey := waitQueueKey(userID) key := waitQueueKey(userID)
result, err := incrementWaitScript.Run( result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
} }
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
userKey := waitQueueKey(userID) key := waitQueueKey(userID)
_, err := decrementWaitScript.Run( _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
return err return err
} }
// GetTotalWaitCount runs getTotalWaitScript against the global wait-queue keys
// and returns the total it reports. When no Redis client is configured it
// returns 0 with a nil error.
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
	if c.rdb == nil {
		return 0, nil
	}

	// cleanup limit per query (rare)
	const cleanupLimit = 500
	keys := []string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey}

	depth, err := getTotalWaitScript.Run(ctx, c.rdb, keys, c.waitQueueTTLSeconds, cleanupLimit).Int64()
	if err != nil {
		return 0, err
	}
	return int(depth), nil
}
// Account wait queue operations // Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID) key := accountWaitKey(accountID)
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result() _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err return err
} }

View File

@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20) userID := int64(20)
userKey := waitQueueKey(userID) waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2) ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1") require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3") require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail") require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result() ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL wait total key") require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int() val, err := s.rdb.Get(s.ctx, waitKey).Int()
require.NoError(s.T(), err, "HGET wait queue count") if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1") require.Equal(s.T(), 1, val, "expected wait count 1")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
} }
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300) userID := int64(300)
userKey := waitQueueKey(userID) waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value // Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key") require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify count remains zero / absent. // Verify no key was created or it's not negative
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int() val, err := s.rdb.Get(s.ctx, waitKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil)) if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice // Set count to 1, then decrement twice
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative // Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero") require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify per-user count is absent and total is non-negative. // Verify count is 0, not negative
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result() val, err = s.rdb.Get(s.ctx, waitKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
if !errors.Is(err, redis.Nil) { if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err) require.NoError(s.T(), err, "Get waitKey after double decrement")
} }
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {

View File

@@ -4,9 +4,11 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"strings"
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
) )
@@ -17,11 +19,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -69,13 +71,33 @@ type windowStatsCache struct {
timestamp time.Time timestamp time.Time
} }
var ( // antigravityUsageCache 缓存 Antigravity 额度数据
apiCacheMap = sync.Map{} // 缓存 API 响应 type antigravityUsageCache struct {
windowStatsCacheMap = sync.Map{} // 缓存窗口统计 usageInfo *UsageInfo
timestamp time.Time
}
const (
apiCacheTTL = 10 * time.Minute apiCacheTTL = 10 * time.Minute
windowStatsCacheTTL = 1 * time.Minute windowStatsCacheTTL = 1 * time.Minute
) )
// UsageCache bundles the in-memory caches used for account usage lookups.
// Each field is a sync.Map keyed by account ID.
type UsageCache struct {
apiCache *sync.Map // accountID -> *apiUsageCache
windowStatsCache *sync.Map // accountID -> *windowStatsCache
antigravityCache *sync.Map // accountID -> *antigravityUsageCache
}
// NewUsageCache returns a UsageCache whose three maps are allocated and empty.
func NewUsageCache() *UsageCache {
	uc := new(UsageCache)
	uc.apiCache = new(sync.Map)
	uc.windowStatsCache = new(sync.Map)
	uc.antigravityCache = new(sync.Map)
	return uc
}
// WindowStats 窗口期统计 // WindowStats 窗口期统计
type WindowStats struct { type WindowStats struct {
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
@@ -91,6 +113,12 @@ type UsageProgress struct {
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
} }
// AntigravityModelQuota holds the quota information for a single Antigravity model.
type AntigravityModelQuota struct {
Utilization int `json:"utilization"` // utilization percentage, 0-100
ResetTime string `json:"reset_time"` // reset time, ISO 8601
}
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -99,6 +127,9 @@ type UsageInfo struct {
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
} }
// ClaudeUsageResponse Anthropic API返回的usage结构 // ClaudeUsageResponse Anthropic API返回的usage结构
@@ -124,19 +155,51 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher QuotaFetcher
cache *UsageCache
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService { func NewAccountUsageService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageFetcher ClaudeUsageFetcher,
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
) *AccountUsageService {
if cache == nil {
cache = &UsageCache{
apiCache: &sync.Map{},
antigravityCache: &sync.Map{},
windowStatsCache: &sync.Map{},
}
}
if cache.apiCache == nil {
cache.apiCache = &sync.Map{}
}
if cache.antigravityCache == nil {
cache.antigravityCache = &sync.Map{}
}
if cache.windowStatsCache == nil {
cache.windowStatsCache = &sync.Map{}
}
var quotaFetcher QuotaFetcher
if antigravityQuotaFetcher != nil {
quotaFetcher = antigravityQuotaFetcher
}
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher, usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService, geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: quotaFetcher,
cache: cache,
} }
} }
@@ -154,12 +217,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return s.getGeminiUsage(ctx, account) return s.getGeminiUsage(ctx, account)
} }
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
return s.getAntigravityUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope // 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() { if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse var apiResp *ClaudeUsageResponse
// 1. 检查 API 缓存(10 分钟) // 1. 检查 API 缓存(10 分钟)
if cached, ok := apiCacheMap.Load(accountID); ok { if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
apiResp = cache.response apiResp = cache.response
} }
@@ -172,7 +240,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, err return nil, err
} }
// 缓存 API 响应 // 缓存 API 响应
apiCacheMap.Store(accountID, &apiUsageCache{ s.cache.apiCache.Store(accountID, &apiUsageCache{
response: apiResp, response: apiResp,
timestamp: time.Now(), timestamp: time.Now(),
}) })
@@ -224,12 +292,70 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
totals := geminiAggregateUsage(stats) totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now) resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost)
return usage, nil return usage, nil
} }
// getAntigravityUsage returns quota usage for an Antigravity account.
//
// When no quota fetcher is configured, or the fetcher reports it cannot fetch
// for this account, it returns a UsageInfo carrying only UpdatedAt. Otherwise
// it makes sure the account has a project_id credential, serves from the
// antigravity cache while the entry is younger than apiCacheTTL, and on a miss
// fetches the quota and caches it.
//
// Callers always receive a copy of the cached UsageInfo, never the cached
// pointer itself, so mutating the returned value cannot change cached data.
//
// Returns an error only when FetchQuota fails.
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
	if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
		now := time.Now()
		return &UsageInfo{UpdatedAt: &now}, nil
	}

	// Ensure project_id is stable for quota queries.
	if strings.TrimSpace(account.GetCredential("project_id")) == "" {
		projectID := antigravity.GenerateMockProjectID()
		if account.Credentials == nil {
			account.Credentials = map[string]any{}
		}
		account.Credentials["project_id"] = projectID
		if s.accountRepo != nil {
			_, err := s.accountRepo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{
				Credentials: map[string]any{"project_id": projectID},
			})
			if err != nil {
				// Best effort: the in-memory account already carries the new ID.
				log.Printf("Failed to persist antigravity project_id for account %d: %v", account.ID, err)
			}
		}
	}

	// 1. Check the cache (entries are valid for apiCacheTTL).
	if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
		entry, ok := cached.(*antigravityUsageCache)
		// A nil usageInfo is treated as a miss: cloneUsageInfo would return nil
		// and the field access below would panic.
		if ok && entry.usageInfo != nil && time.Since(entry.timestamp) < apiCacheTTL {
			usage := cloneUsageInfo(entry.usageInfo)
			// Recompute RemainingSeconds, which goes stale while cached.
			if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
				usage.FiveHour.RemainingSeconds = remainingSecondsUntil(*usage.FiveHour.ResetsAt)
			}
			return usage, nil
		}
	}

	// 2. Resolve the proxy URL; on failure fall back to an empty proxy URL.
	proxyURL, err := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
	if err != nil {
		log.Printf("Failed to get proxy URL for account %d: %v", account.ID, err)
		proxyURL = ""
	}

	// 3. Fetch the quota from the API.
	result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
	}
	// NOTE(review): assumes FetchQuota never returns a nil result together with
	// a nil error — confirm against the QuotaFetcher implementation.
	fetched := result.UsageInfo
	if fetched == nil {
		// Do not cache a nil entry; answer like the "cannot fetch" case above.
		log.Printf("Antigravity quota fetch returned no usage info for account %d", account.ID)
		now := time.Now()
		return &UsageInfo{UpdatedAt: &now}, nil
	}

	// 4. Cache the result and hand the caller its own copy.
	s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
		usageInfo: fetched,
		timestamp: time.Now(),
	})

	return cloneUsageInfo(fetched), nil
}
// addWindowStats 为 usage 数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离 // 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -241,7 +367,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 检查窗口统计缓存(1 分钟) // 检查窗口统计缓存(1 分钟)
var windowStats *WindowStats var windowStats *WindowStats
if cached, ok := windowStatsCacheMap.Load(account.ID); ok { if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
windowStats = cache.stats windowStats = cache.stats
} }
@@ -269,7 +395,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
} }
// 缓存窗口统计(1 分钟) // 缓存窗口统计(1 分钟)
windowStatsCacheMap.Store(account.ID, &windowStatsCache{ s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
stats: windowStats, stats: windowStats,
timestamp: time.Now(), timestamp: time.Now(),
}) })
@@ -342,12 +468,12 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空) // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
info.FiveHour = &UsageProgress{ info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization, Utilization: clampFloat64(resp.FiveHour.Utilization, 0, 100),
} }
if resp.FiveHour.ResetsAt != "" { if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour.ResetsAt = &fiveHourReset info.FiveHour.ResetsAt = &fiveHourReset
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds()) info.FiveHour.RemainingSeconds = remainingSecondsUntil(fiveHourReset)
} else { } else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
} }
@@ -357,14 +483,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDay.ResetsAt != "" { if resp.SevenDay.ResetsAt != "" {
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil { if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
info.SevenDay = &UsageProgress{ info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization, Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
ResetsAt: &sevenDayReset, ResetsAt: &sevenDayReset,
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()), RemainingSeconds: remainingSecondsUntil(sevenDayReset),
} }
} else { } else {
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err) log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
info.SevenDay = &UsageProgress{ info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization, Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
} }
} }
} }
@@ -373,14 +499,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDaySonnet.ResetsAt != "" { if resp.SevenDaySonnet.ResetsAt != "" {
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil { if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
info.SevenDaySonnet = &UsageProgress{ info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization, Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
ResetsAt: &sonnetReset, ResetsAt: &sonnetReset,
RemainingSeconds: int(time.Until(sonnetReset).Seconds()), RemainingSeconds: remainingSecondsUntil(sonnetReset),
} }
} else { } else {
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err) log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
info.SevenDaySonnet = &UsageProgress{ info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization, Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
} }
} }
} }
@@ -394,10 +520,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// 如果有session_window信息 // 如果有session_window信息
if account.SessionWindowEnd != nil { if account.SessionWindowEnd != nil {
remaining := int(time.Until(*account.SessionWindowEnd).Seconds()) remaining := remainingSecondsUntil(*account.SessionWindowEnd)
if remaining < 0 {
remaining = 0
}
// 根据状态估算使用率 (百分比形式100 = 100%) // 根据状态估算使用率 (百分比形式100 = 100%)
var utilization float64 var utilization float64
@@ -409,6 +532,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
default: default:
utilization = 0.0 utilization = 0.0
} }
utilization = clampFloat64(utilization, 0, 100)
info.FiveHour = &UsageProgress{ info.FiveHour = &UsageProgress{
Utilization: utilization, Utilization: utilization,
@@ -427,15 +551,12 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
return info return info
} }
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64) *UsageProgress {
if limit <= 0 { if limit <= 0 {
return nil return nil
} }
utilization := (float64(used) / float64(limit)) * 100 utilization := clampFloat64((float64(used)/float64(limit))*100, 0, 100)
remainingSeconds := int(resetAt.Sub(now).Seconds()) remainingSeconds := remainingSecondsUntil(resetAt)
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt resetCopy := resetAt
return &UsageProgress{ return &UsageProgress{
Utilization: utilization, Utilization: utilization,
@@ -448,3 +569,47 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
}, },
} }
} }
// cloneUsageInfo returns a copy of src that does not share its pointer fields
// or its AntigravityQuota map with the original.
//
// All plain fields are copied by value. UpdatedAt and the five usage-progress
// windows are re-allocated, and a new AntigravityQuota map is built in which
// every non-nil entry points at its own (shallow) copy of the quota value.
// A nil src yields nil; a nil AntigravityQuota map stays nil.
func cloneUsageInfo(src *UsageInfo) *UsageInfo {
	if src == nil {
		return nil
	}

	// Start from a field-by-field copy, then replace everything that aliases src.
	out := new(UsageInfo)
	*out = *src

	if updatedAt := src.UpdatedAt; updatedAt != nil {
		stamp := *updatedAt
		out.UpdatedAt = &stamp
	}

	out.FiveHour = cloneUsageProgress(src.FiveHour)
	out.SevenDay = cloneUsageProgress(src.SevenDay)
	out.SevenDaySonnet = cloneUsageProgress(src.SevenDaySonnet)
	out.GeminiProDaily = cloneUsageProgress(src.GeminiProDaily)
	out.GeminiFlashDaily = cloneUsageProgress(src.GeminiFlashDaily)

	if src.AntigravityQuota == nil {
		return out
	}

	quotas := make(map[string]*AntigravityModelQuota, len(src.AntigravityQuota))
	for key, quota := range src.AntigravityQuota {
		// Keys that map to nil are kept, so the clone has the same key set.
		var dup *AntigravityModelQuota
		if quota != nil {
			value := *quota
			dup = &value
		}
		quotas[key] = dup
	}
	out.AntigravityQuota = quotas

	return out
}
// cloneUsageProgress returns a copy of src in which ResetsAt and WindowStats
// point at freshly allocated values instead of the originals.
//
// The remaining fields are copied by value. A nil src yields nil, and nil
// ResetsAt / WindowStats pointers stay nil in the copy.
func cloneUsageProgress(src *UsageProgress) *UsageProgress {
	if src == nil {
		return nil
	}

	out := new(UsageProgress)
	*out = *src

	if resetsAt := src.ResetsAt; resetsAt != nil {
		when := *resetsAt
		out.ResetsAt = &when
	}

	if stats := src.WindowStats; stats != nil {
		dup := *stats // shallow copy of the stats value
		out.WindowStats = &dup
	}

	return out
}

View File

@@ -32,7 +32,6 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL // 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)
// 批量负载查询(只读) // 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -201,14 +200,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
} }
} }
// GetTotalWaitCount reports the wait-queue depth summed over all users by
// delegating to the configured cache. When the service has no cache it
// returns 0 with a nil error.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
	if s.cache != nil {
		return s.cache.GetTotalWaitCount(ctx)
	}
	return 0, nil
}
// IncrementAccountWaitCount increments the wait queue counter for an account. // IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil { if s.cache == nil {

View File

@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
// SetTempUnschedulable is a no-op stub on the platform test mock: it ignores
// ctx, id, until and reason and always returns nil.
// NOTE(review): presumably present so the mock satisfies the account
// repository interface used by the gateway tests — confirm against the interface.
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
// ClearTempUnschedulable is a no-op stub on the platform test mock: it ignores
// ctx and id and always returns nil.
// NOTE(review): presumably the counterpart of SetTempUnschedulable required by
// the account repository interface — confirm against the interface.
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error { func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil return nil
} }
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},

View File

@@ -75,9 +75,19 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// FilterThinkingBlocks removes thinking blocks from request body // FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe) // Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures // This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte { func FilterThinkingBlocks(body []byte) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing // Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte("thinking")) { if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body return body
} }
@@ -86,6 +96,14 @@ func FilterThinkingBlocks(body []byte) []byte {
return body // Return original on parse error return body // Return original on parse error
} }
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}
messages, ok := req["messages"].([]any) messages, ok := req["messages"].([]any)
if !ok { if !ok {
return body // No messages array return body // No messages array
@@ -98,6 +116,7 @@ func FilterThinkingBlocks(body []byte) []byte {
continue continue
} }
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if !ok { if !ok {
continue continue
@@ -106,6 +125,7 @@ func FilterThinkingBlocks(body []byte) []byte {
// Filter thinking blocks from content array // Filter thinking blocks from content array
newContent := make([]any, 0, len(content)) newContent := make([]any, 0, len(content))
filteredThisMessage := false filteredThisMessage := false
for _, block := range content { for _, block := range content {
blockMap, ok := block.(map[string]any) blockMap, ok := block.(map[string]any)
if !ok { if !ok {
@@ -114,22 +134,34 @@ func FilterThinkingBlocks(body []byte) []byte {
} }
blockType, _ := blockMap["type"].(string) blockType, _ := blockMap["type"].(string)
// Explicit Anthropic-style thinking block: {"type":"thinking", ...}
if blockType == "thinking" { // Handle thinking/redacted_thinking blocks
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid (non-empty, non-dummy) signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
// Keep blocks with valid signatures, remove those without
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}
filtered = true filtered = true
filteredThisMessage = true filteredThisMessage = true
continue // Skip thinking blocks continue
} }
// Some clients send the "thinking" object without a "type" discriminator. // Some clients send the "thinking" object without a "type" discriminator.
// Vertex/Claude still expects a signature for any thinking block, so we drop it.
// We intentionally do not drop other typed blocks (e.g. tool_use) that might // We intentionally do not drop other typed blocks (e.g. tool_use) that might
// legitimately contain a "thinking" key inside their payload. // legitimately contain a "thinking" key inside their payload.
if blockType == "" { if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking { if thinkingContent, hasThinking := blockMap["thinking"]; hasThinking {
_ = thinkingContent
filtered = true filtered = true
filteredThisMessage = true filteredThisMessage = true
continue // Skip thinking blocks continue
} }
} }

View File

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken: case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow // Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account) return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey: case AccountTypeApiKey:
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号 // 应用模型映射仅对apikey类型账号
originalModel := reqModel originalModel := reqModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
// 替换请求体中的模型名 // 替换请求体中的模型名
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err) return nil, fmt.Errorf("upstream request failed: %w", err)
} }
// 检查是否需要重试 // 优先检测thinking block签名错误400并重试一次
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
resp = retryResp
break
}
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries { if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v", log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义 // 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body) respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil { if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream // ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" targetURL = baseURL + "/v1/messages"
} }
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理 // 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" { if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false return false
} }
func defaultAPIKeyBetaHeader(body []byte) string { func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String() modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader return claude.ApiKeyHaikuBetaHeader
} }
return claude.APIKeyBetaHeader return claude.ApiKeyBetaHeader
} }
func truncateForLog(b []byte, maxBytes int) string { func truncateForLog(b []byte, maxBytes int) string {
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
return s return s
} }
// isThinkingBlockSignatureError reports whether an upstream error body
// describes a thinking-block signature failure. Errors of this kind can be
// worked around by filtering the thinking blocks out of the request and
// retrying.
//
// The message extracted from respBody is trimmed and lower-cased. It matches
// when it mentions "signature" together with either "thinking" or "thought",
// for example: "Invalid `signature` in `thinking` block". An empty extracted
// message never matches.
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
	normalized := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
	normalized = strings.ToLower(normalized)
	if len(normalized) == 0 {
		return false
	}

	// Both conditions are required: a signature complaint that is about a
	// thinking/thought block.
	if !strings.Contains(normalized, "signature") {
		return false
	}
	return strings.Contains(normalized, "thinking") || strings.Contains(normalized, "thought")
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。 // 默认保守:无法识别则不切换。
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息) // 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string var errType, errMsg string
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
APIKey *APIKey ApiKey *ApiKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.APIKey apiKey := input.ApiKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算 // Antigravity 账户不支持 count_tokens 转发,直接返回空
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100}) c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil return nil
} }
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeApiKey {
if reqModel != "" { if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" targetURL = baseURL + "/v1/messages/count_tokens"
} }
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭) // API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" { if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }