feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops
Key changes: - Upgrade model mapping: Opus 4.5 → Opus 4.6-thinking with precise matching - Unified rate limiting: scope-level → model-level with Redis snapshot sync - Load-balanced scheduling by call count with smart retry mechanism - Force cache billing support - Model identity injection in prompts with leak prevention - Thinking mode auto-handling (max_tokens/budget_tokens fix) - Frontend: whitelist mode toggle, model mapping validation, status indicators - Gemini session fallback with Redis Trie O(L) matching - Ops: enhanced concurrency monitoring, account availability, retry logic - Migration scripts: 049-051 for model mapping unification
This commit is contained in:
@@ -194,6 +194,53 @@ var (
|
||||
return result
|
||||
`)
|
||||
|
||||
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
|
||||
getUsersLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local userID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:user:' .. userID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'concurrency:wait:' .. userID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, userID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
@@ -384,6 +431,43 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
if len(users) == 0 {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, u := range users {
|
||||
args = append(args, u.ID, u.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.UserLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
}
|
||||
|
||||
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[userID] = &service.UserLoadInfo{
|
||||
UserID: userID,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
|
||||
@@ -11,6 +11,63 @@ import (
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
// Gemini Trie Lua 脚本
|
||||
const (
|
||||
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
||||
// ARGV[2] = TTL seconds (用于刷新)
|
||||
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
geminiTrieFindScript = `
|
||||
local chain = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local lastMatch = nil
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
local val = redis.call('HGET', KEYS[1], path)
|
||||
if val and val ~= "" then
|
||||
lastMatch = val
|
||||
end
|
||||
end
|
||||
|
||||
if lastMatch then
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
end
|
||||
|
||||
return lastMatch
|
||||
`
|
||||
|
||||
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
||||
// KEYS[1] = trie key
|
||||
// ARGV[1] = digestChain
|
||||
// ARGV[2] = value (uuid:accountID)
|
||||
// ARGV[3] = TTL seconds
|
||||
geminiTrieSaveScript = `
|
||||
local chain = ARGV[1]
|
||||
local value = ARGV[2]
|
||||
local ttl = tonumber(ARGV[3])
|
||||
local path = ""
|
||||
|
||||
for part in string.gmatch(chain, "[^-]+") do
|
||||
path = path == "" and part or path .. "-" .. part
|
||||
end
|
||||
redis.call('HSET', KEYS[1], path, value)
|
||||
redis.call('EXPIRE', KEYS[1], ttl)
|
||||
return "OK"
|
||||
`
|
||||
)
|
||||
|
||||
// 模型负载统计相关常量
|
||||
const (
|
||||
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
||||
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
|
||||
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
|
||||
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
|
||||
)
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -51,3 +108,133 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ============ Antigravity 模型负载统计方法 ============
|
||||
|
||||
// modelLoadKey 构建模型调用次数 key
|
||||
// 格式: ag:model_load:{accountID}:{model}
|
||||
func modelLoadKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// modelLastUsedKey 构建模型最后调度时间 key
|
||||
// 格式: ag:model_last_used:{accountID}:{model}
|
||||
func modelLastUsedKey(accountID int64, model string) string {
|
||||
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
|
||||
}
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
|
||||
// 返回更新后的调用次数
|
||||
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
||||
loadKey := modelLoadKey(accountID, model)
|
||||
lastUsedKey := modelLastUsedKey(accountID, model)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, loadKey)
|
||||
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
|
||||
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return incrCmd.Val(), nil
|
||||
}
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息
|
||||
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]*service.ModelLoadInfo), nil
|
||||
}
|
||||
|
||||
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
|
||||
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
|
||||
}
|
||||
|
||||
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
|
||||
func (c *gatewayCache) pipelineModelLoadGet(
|
||||
ctx context.Context,
|
||||
accountIDs []int64,
|
||||
model string,
|
||||
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
|
||||
pipe := c.rdb.Pipeline()
|
||||
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
||||
|
||||
for _, id := range accountIDs {
|
||||
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
|
||||
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
|
||||
}
|
||||
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
|
||||
return loadCmds, lastUsedCmds
|
||||
}
|
||||
|
||||
// parseModelLoadResults 解析 Pipeline 结果
|
||||
func (c *gatewayCache) parseModelLoadResults(
|
||||
accountIDs []int64,
|
||||
loadCmds map[int64]*redis.StringCmd,
|
||||
lastUsedCmds map[int64]*redis.StringCmd,
|
||||
) map[int64]*service.ModelLoadInfo {
|
||||
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = &service.ModelLoadInfo{
|
||||
CallCount: getInt64OrZero(loadCmds[id]),
|
||||
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
|
||||
func getInt64OrZero(cmd *redis.StringCmd) int64 {
|
||||
val, _ := cmd.Int64()
|
||||
return val
|
||||
}
|
||||
|
||||
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
|
||||
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
||||
val, err := cmd.Int64()
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(val, 0)
|
||||
}
|
||||
|
||||
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
||||
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
||||
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
||||
if err != nil || result == nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
value, ok := result.(string)
|
||||
if !ok || value == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
||||
return uuid, accountID, ok
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
||||
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
||||
|
||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
||||
}
|
||||
|
||||
@@ -104,6 +104,158 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
// ============ Gemini Trie 会话测试 ============
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "testprefix"
|
||||
digestChain := "u:hash1-m:hash2-u:hash3"
|
||||
uuid := "test-uuid-123"
|
||||
accountID := int64(42)
|
||||
|
||||
// 保存会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
require.NoError(s.T(), err, "SaveGeminiSession")
|
||||
|
||||
// 精确匹配查找
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
|
||||
require.True(s.T(), found, "should find exact match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "prefixmatch"
|
||||
shortChain := "u:a-m:b"
|
||||
longChain := "u:a-m:b-u:c-m:d"
|
||||
uuid := "uuid-prefix"
|
||||
accountID := int64(100)
|
||||
|
||||
// 保存短链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用长链查找,应该匹配到短链(前缀匹配)
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
|
||||
require.True(s.T(), found, "should find prefix match")
|
||||
require.Equal(s.T(), uuid, foundUUID)
|
||||
require.Equal(s.T(), accountID, foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "longestmatch"
|
||||
|
||||
// 保存多个不同长度的链
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
|
||||
require.NoError(s.T(), err)
|
||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 查找更长的链,应该匹配到最长的前缀
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
|
||||
require.True(s.T(), found, "should find longest prefix match")
|
||||
require.Equal(s.T(), "uuid-long", foundUUID)
|
||||
require.Equal(s.T(), int64(3), foundAccountID)
|
||||
|
||||
// 查找中等长度的链
|
||||
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-medium", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "nomatch"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存一个会话
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用不同的链查找,应该找不到
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
|
||||
require.False(s.T(), found, "should not find non-matching chain")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
|
||||
groupID := int64(1)
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 prefixHash1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
|
||||
require.False(s.T(), found, "different prefixHash should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
|
||||
prefixHash := "sameprefix"
|
||||
digestChain := "u:a-m:b"
|
||||
|
||||
// 保存到 groupID 1
|
||||
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// 用 groupID 2 查找,应该找不到(分组隔离)
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
|
||||
require.False(s.T(), found, "different groupID should be isolated")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "emptytest"
|
||||
|
||||
// 空链不应该保存
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
|
||||
require.NoError(s.T(), err, "empty chain should not error")
|
||||
|
||||
// 空链查找应该返回 false
|
||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
|
||||
require.False(s.T(), found, "empty chain should not match")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
|
||||
groupID := int64(1)
|
||||
prefixHash := "multisession"
|
||||
|
||||
// 保存多个不同会话(模拟 1000 个并发会话的场景)
|
||||
sessions := []struct {
|
||||
chain string
|
||||
uuid string
|
||||
accountID int64
|
||||
}{
|
||||
{"u:session1", "uuid-1", 1},
|
||||
{"u:session2-m:reply2", "uuid-2", 2},
|
||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||
}
|
||||
|
||||
for _, sess := range sessions {
|
||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
// 验证每个会话都能正确查找
|
||||
for _, sess := range sessions {
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
|
||||
require.True(s.T(), found, "should find session: %s", sess.chain)
|
||||
require.Equal(s.T(), sess.uuid, foundUUID)
|
||||
require.Equal(s.T(), sess.accountID, foundAccountID)
|
||||
}
|
||||
|
||||
// 验证继续对话的场景
|
||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
|
||||
require.True(s.T(), found)
|
||||
require.Equal(s.T(), "uuid-2", foundUUID)
|
||||
require.Equal(s.T(), int64(2), foundAccountID)
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ============ Gateway Cache 模型负载统计集成测试 ============
|
||||
|
||||
type GatewayCacheModelLoadSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestGatewayCacheModelLoadSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheModelLoadSuite))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(123)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 首次调用应返回 1
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
// 第二次调用应返回 2
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count2)
|
||||
|
||||
// 第三次调用应返回 3
|
||||
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), count3)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(456)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "claude-opus-4-5-20251101"
|
||||
|
||||
// 不同模型应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
|
||||
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count1Again)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := int64(111)
|
||||
account2 := int64(222)
|
||||
model := "gemini-2.5-pro"
|
||||
|
||||
// 不同账号应该独立计数
|
||||
count1, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count1)
|
||||
|
||||
count2, err := cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), count2)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询不存在的账号应返回零值
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
require.Equal(t, int64(0), result[9999].CallCount)
|
||||
require.True(t, result[9999].LastUsedAt.IsZero())
|
||||
require.Equal(t, int64(0), result[9998].CallCount)
|
||||
require.True(t, result[9998].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(789)
|
||||
model := "claude-sonnet-4-20250514"
|
||||
|
||||
// 先增加调用次数
|
||||
beforeIncr := time.Now()
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, accountID, model)
|
||||
require.NoError(t, err)
|
||||
afterIncr := time.Now()
|
||||
|
||||
// 获取负载信息
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
loadInfo := result[accountID]
|
||||
require.NotNil(t, loadInfo)
|
||||
require.Equal(t, int64(3), loadInfo.CallCount)
|
||||
require.False(t, loadInfo.LastUsedAt.IsZero())
|
||||
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
|
||||
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
|
||||
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
model := "claude-opus-4-5-20251101"
|
||||
account1 := int64(1001)
|
||||
account2 := int64(1002)
|
||||
account3 := int64(1003) // 不调用
|
||||
|
||||
// account1 调用 2 次
|
||||
_, err := cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
_, err = cache.IncrModelCallCount(ctx, account1, model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// account2 调用 5 次
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = cache.IncrModelCallCount(ctx, account2, model)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 批量获取
|
||||
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
|
||||
require.Equal(t, int64(2), result[account1].CallCount)
|
||||
require.False(t, result[account1].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(5), result[account2].CallCount)
|
||||
require.False(t, result[account2].LastUsedAt.IsZero())
|
||||
|
||||
require.Equal(t, int64(0), result[account3].CallCount)
|
||||
require.True(t, result[account3].LastUsedAt.IsZero())
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
|
||||
t := s.T()
|
||||
rdb := testRedis(t)
|
||||
cache := &gatewayCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := int64(2001)
|
||||
model1 := "claude-sonnet-4-20250514"
|
||||
model2 := "gemini-2.5-pro"
|
||||
|
||||
// 对 model1 调用 3 次
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 获取 model1 的负载
|
||||
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), result1[accountID].CallCount)
|
||||
|
||||
// 获取 model2 的负载(应该为 0)
|
||||
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), result2[accountID].CallCount)
|
||||
}
|
||||
|
||||
// ============ 辅助函数测试 ============
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLoadKey(123, "claude-sonnet-4")
|
||||
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
|
||||
t := s.T()
|
||||
|
||||
key := modelLastUsedKey(456, "gemini-2.5-pro")
|
||||
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
|
||||
}
|
||||
@@ -98,12 +98,16 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
|
||||
// Close file before attempting to remove (required on Windows)
|
||||
_ = out.Close()
|
||||
|
||||
if err != nil {
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user