perf(gateway): optimize load-aware scheduling

Key improvements:
- Improve the accuracy and responsiveness of load-aware scheduling
- Switch AccountUsageService from package-level caches to dependency injection
- Fix SSE/JSON escaping and nil-safety issues
- Restore Google One feature compatibility
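The cache-injection change to AccountUsageService (shown in the diff below) can be wired up roughly as in this minimal Go sketch; the import path and the nil placeholders for the repositories and fetchers are illustrative assumptions, not part of this commit:

package main

import "github.com/Wei-Shaw/sub2api/internal/service" // assumed package path for the service layer

func main() {
	// The usage caches are now injected instead of living in package-level sync.Maps.
	cache := service.NewUsageCache()

	// nil repositories/fetchers are placeholders for brevity; per the constructor in this
	// commit, passing a nil cache would also work because it falls back to a fresh UsageCache.
	svc := service.NewAccountUsageService(nil, nil, nil, nil, nil, cache)
	_ = svc
}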
@@ -27,14 +27,8 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"

// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"

@@ -100,55 +94,27 @@ var (
`)

// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end

local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0
end

local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
local newVal = redis.call('INCR', KEYS[1])

-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end

return 1
`)
@@ -178,111 +144,6 @@ var (

// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end

local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end

local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)

local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)

return 1
`)

// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end

-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end

local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)

local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)

// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
}

func waitQueueKey(userID int64) string {
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}

func accountWaitKey(accountID int64) string {
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations

func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}

func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}

// Account wait queue operations

func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun

func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}


@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {

func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)

ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")

ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)

require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")

val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")

total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}

func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)

// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")

// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")

// Set count to 1, then decrement twice
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")

// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")

total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err)
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}

func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {

@@ -4,9 +4,11 @@ import (
"context"
"fmt"
"log"
"strings"
"sync"
"time"

"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
@@ -17,11 +19,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error

ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)

ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)

@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)

// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +53,7 @@ type UsageLogRepository interface {

// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -69,13 +71,33 @@ type windowStatsCache struct {
timestamp time.Time
}

var (
apiCacheMap = sync.Map{} // 缓存 API 响应
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
// antigravityUsageCache 缓存 Antigravity 额度数据
type antigravityUsageCache struct {
usageInfo *UsageInfo
timestamp time.Time
}

const (
apiCacheTTL = 10 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
)

// UsageCache 封装账户使用量相关的缓存
type UsageCache struct {
apiCache *sync.Map // accountID -> *apiUsageCache
windowStatsCache *sync.Map // accountID -> *windowStatsCache
antigravityCache *sync.Map // accountID -> *antigravityUsageCache
}

// NewUsageCache 创建 UsageCache 实例
func NewUsageCache() *UsageCache {
return &UsageCache{
apiCache: &sync.Map{},
antigravityCache: &sync.Map{},
windowStatsCache: &sync.Map{},
}
}

// WindowStats 窗口期统计
type WindowStats struct {
Requests int64 `json:"requests"`
@@ -91,6 +113,12 @@ type UsageProgress struct {
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}

// AntigravityModelQuota Antigravity 单个模型的配额信息
type AntigravityModelQuota struct {
Utilization int `json:"utilization"` // 使用率 0-100
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
}

// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -99,6 +127,9 @@ type UsageInfo struct {
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额

// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
}

// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -124,19 +155,51 @@ type ClaudeUsageFetcher interface {

// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher QuotaFetcher
cache *UsageCache
}

// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
func NewAccountUsageService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageFetcher ClaudeUsageFetcher,
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
) *AccountUsageService {
if cache == nil {
cache = &UsageCache{
apiCache: &sync.Map{},
antigravityCache: &sync.Map{},
windowStatsCache: &sync.Map{},
}
}
if cache.apiCache == nil {
cache.apiCache = &sync.Map{}
}
if cache.antigravityCache == nil {
cache.antigravityCache = &sync.Map{}
}
if cache.windowStatsCache == nil {
cache.windowStatsCache = &sync.Map{}
}

var quotaFetcher QuotaFetcher
if antigravityQuotaFetcher != nil {
quotaFetcher = antigravityQuotaFetcher
}
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: quotaFetcher,
cache: cache,
}
}

@@ -154,12 +217,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return s.getGeminiUsage(ctx, account)
}

// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
return s.getAntigravityUsage(ctx, account)
}

// 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse

// 1. 检查 API 缓存(10 分钟)
if cached, ok := apiCacheMap.Load(accountID); ok {
if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
apiResp = cache.response
}
@@ -172,7 +240,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, err
}
// 缓存 API 响应
apiCacheMap.Store(accountID, &apiUsageCache{
s.cache.apiCache.Store(accountID, &apiUsageCache{
response: apiResp,
timestamp: time.Now(),
})
@@ -224,12 +292,70 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)

usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost)

return usage, nil
}

// getAntigravityUsage 获取 Antigravity 账户额度
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}

// Ensure project_id is stable for quota queries.
if strings.TrimSpace(account.GetCredential("project_id")) == "" {
projectID := antigravity.GenerateMockProjectID()
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["project_id"] = projectID
if s.accountRepo != nil {
_, err := s.accountRepo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{
Credentials: map[string]any{"project_id": projectID},
})
if err != nil {
log.Printf("Failed to persist antigravity project_id for account %d: %v", account.ID, err)
}
}
}

// 1. 检查缓存(10 分钟)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
// 重新计算 RemainingSeconds
usage := cloneUsageInfo(cache.usageInfo)
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = remainingSecondsUntil(*usage.FiveHour.ResetsAt)
}
return usage, nil
}
}

// 2. 获取代理 URL
proxyURL, err := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
if err != nil {
log.Printf("Failed to get proxy URL for account %d: %v", account.ID, err)
proxyURL = ""
}

// 3. 调用 API 获取额度
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
if err != nil {
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}

// 4. 缓存结果
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: result.UsageInfo,
timestamp: time.Now(),
})

return result.UsageInfo, nil
}

// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -241,7 +367,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou

// 检查窗口统计缓存(1 分钟)
var windowStats *WindowStats
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
windowStats = cache.stats
}
@@ -269,7 +395,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}

// 缓存窗口统计(1 分钟)
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
stats: windowStats,
timestamp: time.Now(),
})
@@ -342,12 +468,12 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA

// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
Utilization: clampFloat64(resp.FiveHour.Utilization, 0, 100),
}
if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour.ResetsAt = &fiveHourReset
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
info.FiveHour.RemainingSeconds = remainingSecondsUntil(fiveHourReset)
} else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
}
@@ -357,14 +483,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDay.ResetsAt != "" {
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
ResetsAt: &sevenDayReset,
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
RemainingSeconds: remainingSecondsUntil(sevenDayReset),
}
} else {
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
}
}
}
@@ -373,14 +499,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
if resp.SevenDaySonnet.ResetsAt != "" {
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
ResetsAt: &sonnetReset,
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
RemainingSeconds: remainingSecondsUntil(sonnetReset),
}
} else {
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
}
}
}
@@ -394,10 +520,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn

// 如果有session_window信息
if account.SessionWindowEnd != nil {
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
if remaining < 0 {
remaining = 0
}
remaining := remainingSecondsUntil(*account.SessionWindowEnd)

// 根据状态估算使用率 (百分比形式,100 = 100%)
var utilization float64
@@ -409,6 +532,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
default:
utilization = 0.0
}
utilization = clampFloat64(utilization, 0, 100)

info.FiveHour = &UsageProgress{
Utilization: utilization,
@@ -427,15 +551,12 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
return info
}

func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
utilization := clampFloat64((float64(used)/float64(limit))*100, 0, 100)
remainingSeconds := remainingSecondsUntil(resetAt)
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
@@ -448,3 +569,47 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}

func cloneUsageInfo(src *UsageInfo) *UsageInfo {
if src == nil {
return nil
}
dst := *src
if src.UpdatedAt != nil {
t := *src.UpdatedAt
dst.UpdatedAt = &t
}
dst.FiveHour = cloneUsageProgress(src.FiveHour)
dst.SevenDay = cloneUsageProgress(src.SevenDay)
dst.SevenDaySonnet = cloneUsageProgress(src.SevenDaySonnet)
dst.GeminiProDaily = cloneUsageProgress(src.GeminiProDaily)
dst.GeminiFlashDaily = cloneUsageProgress(src.GeminiFlashDaily)
if src.AntigravityQuota != nil {
dst.AntigravityQuota = make(map[string]*AntigravityModelQuota, len(src.AntigravityQuota))
for k, v := range src.AntigravityQuota {
if v == nil {
dst.AntigravityQuota[k] = nil
continue
}
copyVal := *v
dst.AntigravityQuota[k] = &copyVal
}
}
return &dst
}

func cloneUsageProgress(src *UsageProgress) *UsageProgress {
if src == nil {
return nil
}
dst := *src
if src.ResetsAt != nil {
t := *src.ResetsAt
dst.ResetsAt = &t
}
if src.WindowStats != nil {
statsCopy := *src.WindowStats
dst.WindowStats = &statsCopy
}
return &dst
}

@@ -32,7 +32,6 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)

// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -201,14 +200,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}

// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}

// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {

@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(

repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

@@ -75,9 +75,19 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte("thinking")) {
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}

@@ -86,6 +96,14 @@ func FilterThinkingBlocks(body []byte) []byte {
return body // Return original on parse error
}

// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}

messages, ok := req["messages"].([]any)
if !ok {
return body // No messages array
@@ -98,6 +116,7 @@ func FilterThinkingBlocks(body []byte) []byte {
continue
}

role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
continue
@@ -106,6 +125,7 @@ func FilterThinkingBlocks(body []byte) []byte {
// Filter thinking blocks from content array
newContent := make([]any, 0, len(content))
filteredThisMessage := false

for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
@@ -114,22 +134,34 @@ func FilterThinkingBlocks(body []byte) []byte {
}

blockType, _ := blockMap["type"].(string)
// Explicit Anthropic-style thinking block: {"type":"thinking", ...}
if blockType == "thinking" {

// Handle thinking/redacted_thinking blocks
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid (non-empty, non-dummy) signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
// Keep blocks with valid signatures, remove those without
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}

filtered = true
filteredThisMessage = true
continue // Skip thinking blocks
continue
}

// Some clients send the "thinking" object without a "type" discriminator.
// Vertex/Claude still expects a signature for any thinking block, so we drop it.
// We intentionally do not drop other typed blocks (e.g. tool_use) that might
// legitimately contain a "thinking" key inside their payload.
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
if thinkingContent, hasThinking := blockMap["thinking"]; hasThinking {
_ = thinkingContent
filtered = true
filteredThisMessage = true
continue // Skip thinking blocks
continue
}
}


@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey:
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A

// 应用模型映射(仅对apikey类型账号)
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}

// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误(400)并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()

if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)

// 过滤thinking blocks并重试
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
resp = retryResp
break
}
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误,恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}

// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}

// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}

// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)

req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}

func defaultAPIKeyBetaHeader(body []byte) string {
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
return claude.ApiKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
return claude.ApiKeyBetaHeader
}

func truncateForLog(b []byte, maxBytes int) string {
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}

// isThinkingBlockSignatureError 检测是否是thinking block签名错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}

// 检测thinking block签名相关的错误
// 例如: "Invalid `signature` in `thinking` block"
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
strings.Contains(msg, "signature")
}

func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)

// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}

// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model

// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil
}

// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}

// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)

req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}