perf(gateway): 优化负载感知调度
主要改进: - 优化负载感知调度的准确性和响应速度 - 将 AccountUsageService 的包级缓存改为依赖注入 - 修复 SSE/JSON 转义和 nil 安全问题 - 恢复 Google One 功能兼容性
This commit is contained in:
@@ -27,14 +27,8 @@ const (
|
|||||||
accountSlotKeyPrefix = "concurrency:account:"
|
accountSlotKeyPrefix = "concurrency:account:"
|
||||||
// 格式: concurrency:user:{userID}
|
// 格式: concurrency:user:{userID}
|
||||||
userSlotKeyPrefix = "concurrency:user:"
|
userSlotKeyPrefix = "concurrency:user:"
|
||||||
|
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||||
// Wait queue keys (global structures)
|
waitQueueKeyPrefix = "concurrency:wait:"
|
||||||
// - total: integer total queue depth across all users
|
|
||||||
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
|
|
||||||
// - counts: hash of userID -> current wait count
|
|
||||||
waitQueueTotalKey = "concurrency:wait:total"
|
|
||||||
waitQueueUpdatedKey = "concurrency:wait:updated"
|
|
||||||
waitQueueCountsKey = "concurrency:wait:counts"
|
|
||||||
// 账号级等待队列计数器格式: wait:account:{accountID}
|
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||||
accountWaitKeyPrefix = "wait:account:"
|
accountWaitKeyPrefix = "wait:account:"
|
||||||
|
|
||||||
@@ -100,55 +94,27 @@ var (
|
|||||||
`)
|
`)
|
||||||
|
|
||||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||||
// KEYS[1] = total key
|
// KEYS[1] = wait queue key
|
||||||
// KEYS[2] = updated zset key
|
// ARGV[1] = maxWait
|
||||||
// KEYS[3] = counts hash key
|
// ARGV[2] = TTL in seconds
|
||||||
// ARGV[1] = userID
|
|
||||||
// ARGV[2] = maxWait
|
|
||||||
// ARGV[3] = TTL in seconds
|
|
||||||
// ARGV[4] = cleanup limit
|
|
||||||
incrementWaitScript = redis.NewScript(`
|
incrementWaitScript = redis.NewScript(`
|
||||||
local totalKey = KEYS[1]
|
local current = redis.call('GET', KEYS[1])
|
||||||
local updatedKey = KEYS[2]
|
if current == false then
|
||||||
local countsKey = KEYS[3]
|
current = 0
|
||||||
|
else
|
||||||
local userID = ARGV[1]
|
current = tonumber(current)
|
||||||
local maxWait = tonumber(ARGV[2])
|
|
||||||
local ttl = tonumber(ARGV[3])
|
|
||||||
local cleanupLimit = tonumber(ARGV[4])
|
|
||||||
|
|
||||||
redis.call('SETNX', totalKey, 0)
|
|
||||||
|
|
||||||
local timeResult = redis.call('TIME')
|
|
||||||
local now = tonumber(timeResult[1])
|
|
||||||
local expireBefore = now - ttl
|
|
||||||
|
|
||||||
-- Cleanup expired users (bounded)
|
|
||||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
|
||||||
for _, uid in ipairs(expired) do
|
|
||||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
|
||||||
if c > 0 then
|
|
||||||
redis.call('DECRBY', totalKey, c)
|
|
||||||
end
|
|
||||||
redis.call('HDEL', countsKey, uid)
|
|
||||||
redis.call('ZREM', updatedKey, uid)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
|
if current >= tonumber(ARGV[1]) then
|
||||||
if current >= maxWait then
|
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
|
|
||||||
local newVal = current + 1
|
local newVal = redis.call('INCR', KEYS[1])
|
||||||
redis.call('HSET', countsKey, userID, newVal)
|
|
||||||
redis.call('ZADD', updatedKey, now, userID)
|
|
||||||
redis.call('INCR', totalKey)
|
|
||||||
|
|
||||||
-- Keep global structures from living forever in totally idle deployments.
|
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||||
local ttlKeep = ttl * 2
|
if newVal == 1 then
|
||||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
end
|
||||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
@@ -178,111 +144,6 @@ var (
|
|||||||
|
|
||||||
// decrementWaitScript - same as before
|
// decrementWaitScript - same as before
|
||||||
decrementWaitScript = redis.NewScript(`
|
decrementWaitScript = redis.NewScript(`
|
||||||
local totalKey = KEYS[1]
|
|
||||||
local updatedKey = KEYS[2]
|
|
||||||
local countsKey = KEYS[3]
|
|
||||||
|
|
||||||
local userID = ARGV[1]
|
|
||||||
local ttl = tonumber(ARGV[2])
|
|
||||||
local cleanupLimit = tonumber(ARGV[3])
|
|
||||||
|
|
||||||
redis.call('SETNX', totalKey, 0)
|
|
||||||
|
|
||||||
local timeResult = redis.call('TIME')
|
|
||||||
local now = tonumber(timeResult[1])
|
|
||||||
local expireBefore = now - ttl
|
|
||||||
|
|
||||||
-- Cleanup expired users (bounded)
|
|
||||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
|
||||||
for _, uid in ipairs(expired) do
|
|
||||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
|
||||||
if c > 0 then
|
|
||||||
redis.call('DECRBY', totalKey, c)
|
|
||||||
end
|
|
||||||
redis.call('HDEL', countsKey, uid)
|
|
||||||
redis.call('ZREM', updatedKey, uid)
|
|
||||||
end
|
|
||||||
|
|
||||||
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
|
|
||||||
if current <= 0 then
|
|
||||||
return 1
|
|
||||||
end
|
|
||||||
|
|
||||||
local newVal = current - 1
|
|
||||||
if newVal <= 0 then
|
|
||||||
redis.call('HDEL', countsKey, userID)
|
|
||||||
redis.call('ZREM', updatedKey, userID)
|
|
||||||
else
|
|
||||||
redis.call('HSET', countsKey, userID, newVal)
|
|
||||||
redis.call('ZADD', updatedKey, now, userID)
|
|
||||||
end
|
|
||||||
redis.call('DECR', totalKey)
|
|
||||||
|
|
||||||
local ttlKeep = ttl * 2
|
|
||||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
|
||||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
|
||||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
|
||||||
|
|
||||||
return 1
|
|
||||||
`)
|
|
||||||
|
|
||||||
// getTotalWaitScript returns the global wait depth with TTL cleanup.
|
|
||||||
// KEYS[1] = total key
|
|
||||||
// KEYS[2] = updated zset key
|
|
||||||
// KEYS[3] = counts hash key
|
|
||||||
// ARGV[1] = TTL in seconds
|
|
||||||
// ARGV[2] = cleanup limit
|
|
||||||
getTotalWaitScript = redis.NewScript(`
|
|
||||||
local totalKey = KEYS[1]
|
|
||||||
local updatedKey = KEYS[2]
|
|
||||||
local countsKey = KEYS[3]
|
|
||||||
|
|
||||||
local ttl = tonumber(ARGV[1])
|
|
||||||
local cleanupLimit = tonumber(ARGV[2])
|
|
||||||
|
|
||||||
redis.call('SETNX', totalKey, 0)
|
|
||||||
|
|
||||||
local timeResult = redis.call('TIME')
|
|
||||||
local now = tonumber(timeResult[1])
|
|
||||||
local expireBefore = now - ttl
|
|
||||||
|
|
||||||
-- Cleanup expired users (bounded)
|
|
||||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
|
||||||
for _, uid in ipairs(expired) do
|
|
||||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
|
||||||
if c > 0 then
|
|
||||||
redis.call('DECRBY', totalKey, c)
|
|
||||||
end
|
|
||||||
redis.call('HDEL', countsKey, uid)
|
|
||||||
redis.call('ZREM', updatedKey, uid)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
|
|
||||||
local total = redis.call('GET', totalKey)
|
|
||||||
if total == false then
|
|
||||||
total = 0
|
|
||||||
local vals = redis.call('HVALS', countsKey)
|
|
||||||
for _, v in ipairs(vals) do
|
|
||||||
total = total + tonumber(v)
|
|
||||||
end
|
|
||||||
redis.call('SET', totalKey, total)
|
|
||||||
end
|
|
||||||
|
|
||||||
local ttlKeep = ttl * 2
|
|
||||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
|
||||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
|
||||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
|
||||||
|
|
||||||
local result = tonumber(redis.call('GET', totalKey) or '0')
|
|
||||||
if result < 0 then
|
|
||||||
result = 0
|
|
||||||
redis.call('SET', totalKey, 0)
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
`)
|
|
||||||
|
|
||||||
// decrementAccountWaitScript - account-level wait queue decrement
|
|
||||||
decrementAccountWaitScript = redis.NewScript(`
|
|
||||||
local current = redis.call('GET', KEYS[1])
|
local current = redis.call('GET', KEYS[1])
|
||||||
if current ~= false and tonumber(current) > 0 then
|
if current ~= false and tonumber(current) > 0 then
|
||||||
redis.call('DECR', KEYS[1])
|
redis.call('DECR', KEYS[1])
|
||||||
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func waitQueueKey(userID int64) string {
|
func waitQueueKey(userID int64) string {
|
||||||
// Historical: per-user string keys were used.
|
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
// Now we use global structures keyed by userID string.
|
|
||||||
return strconv.FormatInt(userID, 10)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func accountWaitKey(accountID int64) string {
|
func accountWaitKey(accountID int64) string {
|
||||||
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
|||||||
// Wait queue operations
|
// Wait queue operations
|
||||||
|
|
||||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
userKey := waitQueueKey(userID)
|
key := waitQueueKey(userID)
|
||||||
result, err := incrementWaitScript.Run(
|
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||||
ctx,
|
|
||||||
c.rdb,
|
|
||||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
|
||||||
userKey,
|
|
||||||
maxWait,
|
|
||||||
c.waitQueueTTLSeconds,
|
|
||||||
200, // cleanup limit per call
|
|
||||||
).Int()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
userKey := waitQueueKey(userID)
|
key := waitQueueKey(userID)
|
||||||
_, err := decrementWaitScript.Run(
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
ctx,
|
|
||||||
c.rdb,
|
|
||||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
|
||||||
userKey,
|
|
||||||
c.waitQueueTTLSeconds,
|
|
||||||
200, // cleanup limit per call
|
|
||||||
).Result()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
|
|
||||||
if c.rdb == nil {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
total, err := getTotalWaitScript.Run(
|
|
||||||
ctx,
|
|
||||||
c.rdb,
|
|
||||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
|
||||||
c.waitQueueTTLSeconds,
|
|
||||||
500, // cleanup limit per query (rare)
|
|
||||||
).Int64()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return int(total), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Account wait queue operations
|
// Account wait queue operations
|
||||||
|
|
||||||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
|
|||||||
|
|
||||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
key := accountWaitKey(accountID)
|
key := accountWaitKey(accountID)
|
||||||
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
|||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||||
userID := int64(20)
|
userID := int64(20)
|
||||||
userKey := waitQueueKey(userID)
|
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
|
||||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||||
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
|||||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||||
|
|
||||||
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
|
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||||
require.NoError(s.T(), err, "TTL wait total key")
|
require.NoError(s.T(), err, "TTL waitKey")
|
||||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
|
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||||
|
|
||||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||||
|
|
||||||
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
|
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
require.NoError(s.T(), err, "HGET wait queue count")
|
if !errors.Is(err, redis.Nil) {
|
||||||
|
require.NoError(s.T(), err, "Get waitKey")
|
||||||
|
}
|
||||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||||
|
|
||||||
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
|
|
||||||
require.NoError(s.T(), err, "GET wait queue total")
|
|
||||||
require.Equal(s.T(), 1, total, "expected total wait count 1")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||||
userID := int64(300)
|
userID := int64(300)
|
||||||
userKey := waitQueueKey(userID)
|
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
|
||||||
// Test decrement on non-existent key - should not error and should not create negative value
|
// Test decrement on non-existent key - should not error and should not create negative value
|
||||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||||
|
|
||||||
// Verify count remains zero / absent.
|
// Verify no key was created or it's not negative
|
||||||
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
|
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
if !errors.Is(err, redis.Nil) {
|
||||||
|
require.NoError(s.T(), err, "Get waitKey")
|
||||||
|
}
|
||||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||||
|
|
||||||
// Set count to 1, then decrement twice
|
// Set count to 1, then decrement twice
|
||||||
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
|||||||
// Decrement again on 0 - should not go negative
|
// Decrement again on 0 - should not go negative
|
||||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||||
|
|
||||||
// Verify per-user count is absent and total is non-negative.
|
// Verify count is 0, not negative
|
||||||
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
|
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
|
|
||||||
|
|
||||||
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
|
|
||||||
if !errors.Is(err, redis.Nil) {
|
if !errors.Is(err, redis.Nil) {
|
||||||
require.NoError(s.T(), err)
|
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||||
}
|
}
|
||||||
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
|
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
)
|
)
|
||||||
@@ -17,11 +19,11 @@ type UsageLogRepository interface {
|
|||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
|
|
||||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||||
|
|
||||||
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
|
|||||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
|
||||||
|
|
||||||
// User dashboard stats
|
// User dashboard stats
|
||||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||||
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
|
|||||||
|
|
||||||
// Aggregated stats (optimized)
|
// Aggregated stats (optimized)
|
||||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||||
@@ -69,13 +71,33 @@ type windowStatsCache struct {
|
|||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// antigravityUsageCache 缓存 Antigravity 额度数据
|
||||||
apiCacheMap = sync.Map{} // 缓存 API 响应
|
type antigravityUsageCache struct {
|
||||||
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
|
usageInfo *UsageInfo
|
||||||
|
timestamp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
apiCacheTTL = 10 * time.Minute
|
apiCacheTTL = 10 * time.Minute
|
||||||
windowStatsCacheTTL = 1 * time.Minute
|
windowStatsCacheTTL = 1 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UsageCache 封装账户使用量相关的缓存
|
||||||
|
type UsageCache struct {
|
||||||
|
apiCache *sync.Map // accountID -> *apiUsageCache
|
||||||
|
windowStatsCache *sync.Map // accountID -> *windowStatsCache
|
||||||
|
antigravityCache *sync.Map // accountID -> *antigravityUsageCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUsageCache 创建 UsageCache 实例
|
||||||
|
func NewUsageCache() *UsageCache {
|
||||||
|
return &UsageCache{
|
||||||
|
apiCache: &sync.Map{},
|
||||||
|
antigravityCache: &sync.Map{},
|
||||||
|
windowStatsCache: &sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WindowStats 窗口期统计
|
// WindowStats 窗口期统计
|
||||||
type WindowStats struct {
|
type WindowStats struct {
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
@@ -91,6 +113,12 @@ type UsageProgress struct {
|
|||||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AntigravityModelQuota Antigravity 单个模型的配额信息
|
||||||
|
type AntigravityModelQuota struct {
|
||||||
|
Utilization int `json:"utilization"` // 使用率 0-100
|
||||||
|
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||||
|
}
|
||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
@@ -99,6 +127,9 @@ type UsageInfo struct {
|
|||||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||||
|
|
||||||
|
// Antigravity 多模型配额
|
||||||
|
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||||
@@ -124,19 +155,51 @@ type ClaudeUsageFetcher interface {
|
|||||||
|
|
||||||
// AccountUsageService 账号使用量查询服务
|
// AccountUsageService 账号使用量查询服务
|
||||||
type AccountUsageService struct {
|
type AccountUsageService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
usageLogRepo UsageLogRepository
|
usageLogRepo UsageLogRepository
|
||||||
usageFetcher ClaudeUsageFetcher
|
usageFetcher ClaudeUsageFetcher
|
||||||
geminiQuotaService *GeminiQuotaService
|
geminiQuotaService *GeminiQuotaService
|
||||||
|
antigravityQuotaFetcher QuotaFetcher
|
||||||
|
cache *UsageCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
|
func NewAccountUsageService(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
usageLogRepo UsageLogRepository,
|
||||||
|
usageFetcher ClaudeUsageFetcher,
|
||||||
|
geminiQuotaService *GeminiQuotaService,
|
||||||
|
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||||
|
cache *UsageCache,
|
||||||
|
) *AccountUsageService {
|
||||||
|
if cache == nil {
|
||||||
|
cache = &UsageCache{
|
||||||
|
apiCache: &sync.Map{},
|
||||||
|
antigravityCache: &sync.Map{},
|
||||||
|
windowStatsCache: &sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cache.apiCache == nil {
|
||||||
|
cache.apiCache = &sync.Map{}
|
||||||
|
}
|
||||||
|
if cache.antigravityCache == nil {
|
||||||
|
cache.antigravityCache = &sync.Map{}
|
||||||
|
}
|
||||||
|
if cache.windowStatsCache == nil {
|
||||||
|
cache.windowStatsCache = &sync.Map{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var quotaFetcher QuotaFetcher
|
||||||
|
if antigravityQuotaFetcher != nil {
|
||||||
|
quotaFetcher = antigravityQuotaFetcher
|
||||||
|
}
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
usageFetcher: usageFetcher,
|
usageFetcher: usageFetcher,
|
||||||
geminiQuotaService: geminiQuotaService,
|
geminiQuotaService: geminiQuotaService,
|
||||||
|
antigravityQuotaFetcher: quotaFetcher,
|
||||||
|
cache: cache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,12 +217,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return s.getGeminiUsage(ctx, account)
|
return s.getGeminiUsage(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||||
|
if account.Platform == PlatformAntigravity {
|
||||||
|
return s.getAntigravityUsage(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||||
if account.CanGetUsage() {
|
if account.CanGetUsage() {
|
||||||
var apiResp *ClaudeUsageResponse
|
var apiResp *ClaudeUsageResponse
|
||||||
|
|
||||||
// 1. 检查 API 缓存(10 分钟)
|
// 1. 检查 API 缓存(10 分钟)
|
||||||
if cached, ok := apiCacheMap.Load(accountID); ok {
|
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||||
apiResp = cache.response
|
apiResp = cache.response
|
||||||
}
|
}
|
||||||
@@ -172,7 +240,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 缓存 API 响应
|
// 缓存 API 响应
|
||||||
apiCacheMap.Store(accountID, &apiUsageCache{
|
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||||
response: apiResp,
|
response: apiResp,
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
})
|
})
|
||||||
@@ -224,12 +292,70 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
|||||||
totals := geminiAggregateUsage(stats)
|
totals := geminiAggregateUsage(stats)
|
||||||
resetAt := geminiDailyResetTime(now)
|
resetAt := geminiDailyResetTime(now)
|
||||||
|
|
||||||
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
|
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost)
|
||||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
|
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost)
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAntigravityUsage 获取 Antigravity 账户额度
|
||||||
|
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||||
|
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
|
||||||
|
now := time.Now()
|
||||||
|
return &UsageInfo{UpdatedAt: &now}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure project_id is stable for quota queries.
|
||||||
|
if strings.TrimSpace(account.GetCredential("project_id")) == "" {
|
||||||
|
projectID := antigravity.GenerateMockProjectID()
|
||||||
|
if account.Credentials == nil {
|
||||||
|
account.Credentials = map[string]any{}
|
||||||
|
}
|
||||||
|
account.Credentials["project_id"] = projectID
|
||||||
|
if s.accountRepo != nil {
|
||||||
|
_, err := s.accountRepo.BulkUpdate(ctx, []int64{account.ID}, AccountBulkUpdate{
|
||||||
|
Credentials: map[string]any{"project_id": projectID},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to persist antigravity project_id for account %d: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 检查缓存(10 分钟)
|
||||||
|
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||||
|
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||||
|
// 重新计算 RemainingSeconds
|
||||||
|
usage := cloneUsageInfo(cache.usageInfo)
|
||||||
|
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||||
|
usage.FiveHour.RemainingSeconds = remainingSecondsUntil(*usage.FiveHour.ResetsAt)
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 获取代理 URL
|
||||||
|
proxyURL, err := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get proxy URL for account %d: %v", account.ID, err)
|
||||||
|
proxyURL = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 调用 API 获取额度
|
||||||
|
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 缓存结果
|
||||||
|
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||||
|
usageInfo: result.UsageInfo,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return result.UsageInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addWindowStats 为 usage 数据添加窗口期统计
|
// addWindowStats 为 usage 数据添加窗口期统计
|
||||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||||
@@ -241,7 +367,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
// 检查窗口统计缓存(1 分钟)
|
// 检查窗口统计缓存(1 分钟)
|
||||||
var windowStats *WindowStats
|
var windowStats *WindowStats
|
||||||
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
|
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
|
||||||
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
|
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
|
||||||
windowStats = cache.stats
|
windowStats = cache.stats
|
||||||
}
|
}
|
||||||
@@ -269,7 +395,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 缓存窗口统计(1 分钟)
|
// 缓存窗口统计(1 分钟)
|
||||||
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
|
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
|
||||||
stats: windowStats,
|
stats: windowStats,
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
})
|
})
|
||||||
@@ -342,12 +468,12 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
|
|||||||
|
|
||||||
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
|
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
|
||||||
info.FiveHour = &UsageProgress{
|
info.FiveHour = &UsageProgress{
|
||||||
Utilization: resp.FiveHour.Utilization,
|
Utilization: clampFloat64(resp.FiveHour.Utilization, 0, 100),
|
||||||
}
|
}
|
||||||
if resp.FiveHour.ResetsAt != "" {
|
if resp.FiveHour.ResetsAt != "" {
|
||||||
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
|
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
|
||||||
info.FiveHour.ResetsAt = &fiveHourReset
|
info.FiveHour.ResetsAt = &fiveHourReset
|
||||||
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
|
info.FiveHour.RemainingSeconds = remainingSecondsUntil(fiveHourReset)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
|
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
|
||||||
}
|
}
|
||||||
@@ -357,14 +483,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
|
|||||||
if resp.SevenDay.ResetsAt != "" {
|
if resp.SevenDay.ResetsAt != "" {
|
||||||
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
|
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
|
||||||
info.SevenDay = &UsageProgress{
|
info.SevenDay = &UsageProgress{
|
||||||
Utilization: resp.SevenDay.Utilization,
|
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
|
||||||
ResetsAt: &sevenDayReset,
|
ResetsAt: &sevenDayReset,
|
||||||
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
|
RemainingSeconds: remainingSecondsUntil(sevenDayReset),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
|
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
|
||||||
info.SevenDay = &UsageProgress{
|
info.SevenDay = &UsageProgress{
|
||||||
Utilization: resp.SevenDay.Utilization,
|
Utilization: clampFloat64(resp.SevenDay.Utilization, 0, 100),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -373,14 +499,14 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
|
|||||||
if resp.SevenDaySonnet.ResetsAt != "" {
|
if resp.SevenDaySonnet.ResetsAt != "" {
|
||||||
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
|
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
|
||||||
info.SevenDaySonnet = &UsageProgress{
|
info.SevenDaySonnet = &UsageProgress{
|
||||||
Utilization: resp.SevenDaySonnet.Utilization,
|
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
|
||||||
ResetsAt: &sonnetReset,
|
ResetsAt: &sonnetReset,
|
||||||
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
|
RemainingSeconds: remainingSecondsUntil(sonnetReset),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
|
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
|
||||||
info.SevenDaySonnet = &UsageProgress{
|
info.SevenDaySonnet = &UsageProgress{
|
||||||
Utilization: resp.SevenDaySonnet.Utilization,
|
Utilization: clampFloat64(resp.SevenDaySonnet.Utilization, 0, 100),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -394,10 +520,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
|||||||
|
|
||||||
// 如果有session_window信息
|
// 如果有session_window信息
|
||||||
if account.SessionWindowEnd != nil {
|
if account.SessionWindowEnd != nil {
|
||||||
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
|
remaining := remainingSecondsUntil(*account.SessionWindowEnd)
|
||||||
if remaining < 0 {
|
|
||||||
remaining = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||||
var utilization float64
|
var utilization float64
|
||||||
@@ -409,6 +532,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
|||||||
default:
|
default:
|
||||||
utilization = 0.0
|
utilization = 0.0
|
||||||
}
|
}
|
||||||
|
utilization = clampFloat64(utilization, 0, 100)
|
||||||
|
|
||||||
info.FiveHour = &UsageProgress{
|
info.FiveHour = &UsageProgress{
|
||||||
Utilization: utilization,
|
Utilization: utilization,
|
||||||
@@ -427,15 +551,12 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64) *UsageProgress {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
utilization := (float64(used) / float64(limit)) * 100
|
utilization := clampFloat64((float64(used)/float64(limit))*100, 0, 100)
|
||||||
remainingSeconds := int(resetAt.Sub(now).Seconds())
|
remainingSeconds := remainingSecondsUntil(resetAt)
|
||||||
if remainingSeconds < 0 {
|
|
||||||
remainingSeconds = 0
|
|
||||||
}
|
|
||||||
resetCopy := resetAt
|
resetCopy := resetAt
|
||||||
return &UsageProgress{
|
return &UsageProgress{
|
||||||
Utilization: utilization,
|
Utilization: utilization,
|
||||||
@@ -448,3 +569,47 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneUsageInfo(src *UsageInfo) *UsageInfo {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dst := *src
|
||||||
|
if src.UpdatedAt != nil {
|
||||||
|
t := *src.UpdatedAt
|
||||||
|
dst.UpdatedAt = &t
|
||||||
|
}
|
||||||
|
dst.FiveHour = cloneUsageProgress(src.FiveHour)
|
||||||
|
dst.SevenDay = cloneUsageProgress(src.SevenDay)
|
||||||
|
dst.SevenDaySonnet = cloneUsageProgress(src.SevenDaySonnet)
|
||||||
|
dst.GeminiProDaily = cloneUsageProgress(src.GeminiProDaily)
|
||||||
|
dst.GeminiFlashDaily = cloneUsageProgress(src.GeminiFlashDaily)
|
||||||
|
if src.AntigravityQuota != nil {
|
||||||
|
dst.AntigravityQuota = make(map[string]*AntigravityModelQuota, len(src.AntigravityQuota))
|
||||||
|
for k, v := range src.AntigravityQuota {
|
||||||
|
if v == nil {
|
||||||
|
dst.AntigravityQuota[k] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copyVal := *v
|
||||||
|
dst.AntigravityQuota[k] = &copyVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneUsageProgress(src *UsageProgress) *UsageProgress {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dst := *src
|
||||||
|
if src.ResetsAt != nil {
|
||||||
|
t := *src.ResetsAt
|
||||||
|
dst.ResetsAt = &t
|
||||||
|
}
|
||||||
|
if src.WindowStats != nil {
|
||||||
|
statsCopy := *src.WindowStats
|
||||||
|
dst.WindowStats = &statsCopy
|
||||||
|
}
|
||||||
|
return &dst
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ type ConcurrencyCache interface {
|
|||||||
// 等待队列计数(只在首次创建时设置 TTL)
|
// 等待队列计数(只在首次创建时设置 TTL)
|
||||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||||
GetTotalWaitCount(ctx context.Context) (int, error)
|
|
||||||
|
|
||||||
// 批量负载查询(只读)
|
// 批量负载查询(只读)
|
||||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||||
@@ -201,14 +200,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTotalWaitCount returns the total wait queue depth across users.
|
|
||||||
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
|
|
||||||
if s.cache == nil {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return s.cache.GetTotalWaitCount(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
if s.cache == nil {
|
if s.cache == nil {
|
||||||
|
|||||||
@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
|
|||||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
|
|||||||
|
|
||||||
repo := &mockAccountRepoForPlatform{
|
repo := &mockAccountRepoForPlatform{
|
||||||
accounts: []Account{
|
accounts: []Account{
|
||||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||||
},
|
},
|
||||||
accountsByID: map[int64]*Account{},
|
accountsByID: map[int64]*Account{},
|
||||||
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
|||||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||||
repo := &mockAccountRepoForPlatform{
|
repo := &mockAccountRepoForPlatform{
|
||||||
accounts: []Account{
|
accounts: []Account{
|
||||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||||
},
|
},
|
||||||
accountsByID: map[int64]*Account{},
|
accountsByID: map[int64]*Account{},
|
||||||
|
|||||||
@@ -75,9 +75,19 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
|||||||
// FilterThinkingBlocks removes thinking blocks from request body
|
// FilterThinkingBlocks removes thinking blocks from request body
|
||||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||||
// This prevents 400 errors from invalid thinking block signatures
|
// This prevents 400 errors from invalid thinking block signatures
|
||||||
|
//
|
||||||
|
// Strategy:
|
||||||
|
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||||
|
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||||
|
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||||
func FilterThinkingBlocks(body []byte) []byte {
|
func FilterThinkingBlocks(body []byte) []byte {
|
||||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||||
if !bytes.Contains(body, []byte("thinking")) {
|
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||||
|
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||||
|
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||||
|
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
|
||||||
|
!bytes.Contains(body, []byte(`"thinking":`)) &&
|
||||||
|
!bytes.Contains(body, []byte(`"thinking" :`)) {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,6 +96,14 @@ func FilterThinkingBlocks(body []byte) []byte {
|
|||||||
return body // Return original on parse error
|
return body // Return original on parse error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if thinking is enabled
|
||||||
|
thinkingEnabled := false
|
||||||
|
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||||
|
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||||
|
thinkingEnabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
messages, ok := req["messages"].([]any)
|
messages, ok := req["messages"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return body // No messages array
|
return body // No messages array
|
||||||
@@ -98,6 +116,7 @@ func FilterThinkingBlocks(body []byte) []byte {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
role, _ := msgMap["role"].(string)
|
||||||
content, ok := msgMap["content"].([]any)
|
content, ok := msgMap["content"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@@ -106,6 +125,7 @@ func FilterThinkingBlocks(body []byte) []byte {
|
|||||||
// Filter thinking blocks from content array
|
// Filter thinking blocks from content array
|
||||||
newContent := make([]any, 0, len(content))
|
newContent := make([]any, 0, len(content))
|
||||||
filteredThisMessage := false
|
filteredThisMessage := false
|
||||||
|
|
||||||
for _, block := range content {
|
for _, block := range content {
|
||||||
blockMap, ok := block.(map[string]any)
|
blockMap, ok := block.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -114,22 +134,34 @@ func FilterThinkingBlocks(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
blockType, _ := blockMap["type"].(string)
|
blockType, _ := blockMap["type"].(string)
|
||||||
// Explicit Anthropic-style thinking block: {"type":"thinking", ...}
|
|
||||||
if blockType == "thinking" {
|
// Handle thinking/redacted_thinking blocks
|
||||||
|
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||||
|
// When thinking is enabled and this is an assistant message,
|
||||||
|
// only keep thinking blocks with valid (non-empty, non-dummy) signatures
|
||||||
|
if thinkingEnabled && role == "assistant" {
|
||||||
|
signature, _ := blockMap["signature"].(string)
|
||||||
|
// Keep blocks with valid signatures, remove those without
|
||||||
|
if signature != "" && signature != "skip_thought_signature_validator" {
|
||||||
|
newContent = append(newContent, block)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
filtered = true
|
filtered = true
|
||||||
filteredThisMessage = true
|
filteredThisMessage = true
|
||||||
continue // Skip thinking blocks
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some clients send the "thinking" object without a "type" discriminator.
|
// Some clients send the "thinking" object without a "type" discriminator.
|
||||||
// Vertex/Claude still expects a signature for any thinking block, so we drop it.
|
|
||||||
// We intentionally do not drop other typed blocks (e.g. tool_use) that might
|
// We intentionally do not drop other typed blocks (e.g. tool_use) that might
|
||||||
// legitimately contain a "thinking" key inside their payload.
|
// legitimately contain a "thinking" key inside their payload.
|
||||||
if blockType == "" {
|
if blockType == "" {
|
||||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
if thinkingContent, hasThinking := blockMap["thinking"]; hasThinking {
|
||||||
|
_ = thinkingContent
|
||||||
filtered = true
|
filtered = true
|
||||||
filteredThisMessage = true
|
filteredThisMessage = true
|
||||||
continue // Skip thinking blocks
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
|||||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||||
// Both oauth and setup-token use OAuth token flow
|
// Both oauth and setup-token use OAuth token flow
|
||||||
return s.getOAuthToken(ctx, account)
|
return s.getOAuthToken(ctx, account)
|
||||||
case AccountTypeAPIKey:
|
case AccountTypeApiKey:
|
||||||
apiKey := account.GetCredential("api_key")
|
apiKey := account.GetCredential("api_key")
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return "", "", errors.New("api_key not found in credentials")
|
return "", "", errors.New("api_key not found in credentials")
|
||||||
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 应用模型映射(仅对apikey类型账号)
|
// 应用模型映射(仅对apikey类型账号)
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeApiKey {
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
// 替换请求体中的模型名
|
// 替换请求体中的模型名
|
||||||
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否需要重试
|
// 优先检测thinking block签名错误(400)并重试一次
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode == 400 {
|
||||||
|
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
if readErr == nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if s.isThinkingBlockSignatureError(respBody) {
|
||||||
|
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||||
|
|
||||||
|
// 过滤thinking blocks并重试
|
||||||
|
filteredBody := FilterThinkingBlocks(body)
|
||||||
|
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||||
|
if buildErr == nil {
|
||||||
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if retryErr == nil {
|
||||||
|
// 使用重试后的响应,继续后续处理
|
||||||
|
resp = retryResp
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 重试失败,恢复原始响应体继续处理
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 不是thinking签名错误,恢复响应体
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
|
||||||
|
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||||
if attempt < maxRetries {
|
if attempt < maxRetries {
|
||||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||||
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理错误响应(不可重试的错误)
|
// 处理错误响应(不可重试的错误)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||||
respBody, readErr := io.ReadAll(resp.Body)
|
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||||
return s.handleErrorResponse(ctx, resp, c, account)
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||||
// 确定目标URL
|
// 确定目标URL
|
||||||
targetURL := claudeAPIURL
|
targetURL := claudeAPIURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeApiKey {
|
||||||
baseURL := account.GetBaseURL()
|
baseURL := account.GetBaseURL()
|
||||||
targetURL = baseURL + "/v1/messages"
|
targetURL = baseURL + "/v1/messages"
|
||||||
}
|
}
|
||||||
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
|
|
||||||
// We apply this for the main /v1/messages path as well as count_tokens.
|
|
||||||
body = FilterThinkingBlocks(body)
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
if requestNeedsBetaFeatures(body) {
|
if requestNeedsBetaFeatures(body) {
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||||
req.Header.Set("anthropic-beta", beta)
|
req.Header.Set("anthropic-beta", beta)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
func defaultApiKeyBetaHeader(body []byte) string {
|
||||||
modelID := gjson.GetBytes(body, "model").String()
|
modelID := gjson.GetBytes(body, "model").String()
|
||||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
return claude.APIKeyHaikuBetaHeader
|
return claude.ApiKeyHaikuBetaHeader
|
||||||
}
|
}
|
||||||
return claude.APIKeyBetaHeader
|
return claude.ApiKeyBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateForLog(b []byte, maxBytes int) string {
|
func truncateForLog(b []byte, maxBytes int) string {
|
||||||
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isThinkingBlockSignatureError 检测是否是thinking block签名错误
|
||||||
|
// 这类错误可以通过过滤thinking blocks并重试来解决
|
||||||
|
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||||
|
if msg == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测thinking block签名相关的错误
|
||||||
|
// 例如: "Invalid `signature` in `thinking` block"
|
||||||
|
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
|
||||||
|
strings.Contains(msg, "signature")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||||
// 默认保守:无法识别则不切换。
|
// 默认保守:无法识别则不切换。
|
||||||
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
|||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
// 处理上游错误,标记账号状态
|
// 处理上游错误,标记账号状态
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
shouldDisable := false
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
|
}
|
||||||
|
if shouldDisable {
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
|
|
||||||
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
||||||
var errType, errMsg string
|
var errType, errMsg string
|
||||||
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
|||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
APIKey *APIKey
|
ApiKey *ApiKey
|
||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
|
|||||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
apiKey := input.APIKey
|
apiKey := input.ApiKey
|
||||||
user := input.User
|
user := input.User
|
||||||
account := input.Account
|
account := input.Account
|
||||||
subscription := input.Subscription
|
subscription := input.Subscription
|
||||||
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
durationMs := int(result.Duration.Milliseconds())
|
durationMs := int(result.Duration.Milliseconds())
|
||||||
usageLog := &UsageLog{
|
usageLog := &UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
APIKeyID: apiKey.ID,
|
ApiKeyID: apiKey.ID,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: result.RequestID,
|
RequestID: result.RequestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
|
|
||||||
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||||
// 参考 Antigravity-Manager 和 proxycast 实现
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
|
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用模型映射(仅对 apikey 类型账号)
|
// 应用模型映射(仅对 apikey 类型账号)
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeApiKey {
|
||||||
if reqModel != "" {
|
if reqModel != "" {
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||||
// 确定目标 URL
|
// 确定目标 URL
|
||||||
targetURL := claudeAPICountTokensURL
|
targetURL := claudeAPICountTokensURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeApiKey {
|
||||||
baseURL := account.GetBaseURL()
|
baseURL := account.GetBaseURL()
|
||||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||||
}
|
}
|
||||||
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
|
|
||||||
body = FilterThinkingBlocks(body)
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||||
if requestNeedsBetaFeatures(body) {
|
if requestNeedsBetaFeatures(body) {
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||||
req.Header.Set("anthropic-beta", beta)
|
req.Header.Set("anthropic-beta", beta)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user