refactor: replace Trie-based digest session store with flat cache
This commit is contained in:
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ require (
|
|||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
|||||||
@@ -213,6 +213,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
|||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var geminiDigestChain string
|
var geminiDigestChain string
|
||||||
var geminiPrefixHash string
|
var geminiPrefixHash string
|
||||||
var geminiSessionUUID string
|
var geminiSessionUUID string
|
||||||
|
var matchedDigestChain string
|
||||||
useDigestFallback := sessionBoundAccountID == 0
|
useDigestFallback := sessionBoundAccountID == 0
|
||||||
|
|
||||||
if useDigestFallback {
|
if useDigestFallback {
|
||||||
@@ -285,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// 查找会话
|
// 查找会话
|
||||||
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
|
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
derefGroupID(apiKey.GroupID),
|
derefGroupID(apiKey.GroupID),
|
||||||
geminiPrefixHash,
|
geminiPrefixHash,
|
||||||
geminiDigestChain,
|
geminiDigestChain,
|
||||||
)
|
)
|
||||||
if found {
|
if found {
|
||||||
|
matchedDigestChain = foundMatchedChain
|
||||||
sessionBoundAccountID = foundAccountID
|
sessionBoundAccountID = foundAccountID
|
||||||
geminiSessionUUID = foundUUID
|
geminiSessionUUID = foundUUID
|
||||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||||
@@ -458,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
geminiDigestChain,
|
geminiDigestChain,
|
||||||
geminiSessionUUID,
|
geminiSessionUUID,
|
||||||
account.ID,
|
account.ID,
|
||||||
|
matchedDigestChain,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,64 +11,6 @@ import (
|
|||||||
|
|
||||||
const stickySessionPrefix = "sticky_session:"
|
const stickySessionPrefix = "sticky_session:"
|
||||||
|
|
||||||
// Gemini Trie Lua 脚本
|
|
||||||
const (
|
|
||||||
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
|
||||||
// KEYS[1] = trie key
|
|
||||||
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
|
||||||
// ARGV[2] = TTL seconds (用于刷新)
|
|
||||||
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
|
||||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
||||||
// 从最长前缀(完整 chain)开始逐步缩短,第一次命中即返回
|
|
||||||
geminiTrieFindScript = `
|
|
||||||
local chain = ARGV[1]
|
|
||||||
local ttl = tonumber(ARGV[2])
|
|
||||||
|
|
||||||
-- 先尝试完整 chain(最常见场景:同一对话的下一轮请求)
|
|
||||||
local val = redis.call('HGET', KEYS[1], chain)
|
|
||||||
if val and val ~= "" then
|
|
||||||
redis.call('EXPIRE', KEYS[1], ttl)
|
|
||||||
return val
|
|
||||||
end
|
|
||||||
|
|
||||||
-- 从最长前缀开始逐步缩短(去掉最后一个 "-xxx" 段)
|
|
||||||
local path = chain
|
|
||||||
while true do
|
|
||||||
local i = string.find(path, "-[^-]*$")
|
|
||||||
if not i or i <= 1 then
|
|
||||||
break
|
|
||||||
end
|
|
||||||
path = string.sub(path, 1, i - 1)
|
|
||||||
val = redis.call('HGET', KEYS[1], path)
|
|
||||||
if val and val ~= "" then
|
|
||||||
redis.call('EXPIRE', KEYS[1], ttl)
|
|
||||||
return val
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return nil
|
|
||||||
`
|
|
||||||
|
|
||||||
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
|
||||||
// KEYS[1] = trie key
|
|
||||||
// ARGV[1] = digestChain
|
|
||||||
// ARGV[2] = value (uuid:accountID)
|
|
||||||
// ARGV[3] = TTL seconds
|
|
||||||
geminiTrieSaveScript = `
|
|
||||||
local chain = ARGV[1]
|
|
||||||
local value = ARGV[2]
|
|
||||||
local ttl = tonumber(ARGV[3])
|
|
||||||
local path = ""
|
|
||||||
|
|
||||||
for part in string.gmatch(chain, "[^-]+") do
|
|
||||||
path = path == "" and part or path .. "-" .. part
|
|
||||||
end
|
|
||||||
redis.call('HSET', KEYS[1], path, value)
|
|
||||||
redis.call('EXPIRE', KEYS[1], ttl)
|
|
||||||
return "OK"
|
|
||||||
`
|
|
||||||
)
|
|
||||||
|
|
||||||
// 模型负载统计相关常量
|
// 模型负载统计相关常量
|
||||||
const (
|
const (
|
||||||
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
||||||
@@ -206,82 +148,3 @@ func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
|||||||
}
|
}
|
||||||
return time.Unix(val, 0)
|
return time.Unix(val, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
|
||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
|
||||||
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
|
||||||
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
if digestChain == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
||||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
||||||
|
|
||||||
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
|
||||||
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
||||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
|
||||||
if err != nil || result == nil {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, ok := result.(string)
|
|
||||||
if !ok || value == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
|
||||||
return uuid, accountID, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
|
||||||
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
if digestChain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
||||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
||||||
|
|
||||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
|
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
|
||||||
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
if digestChain == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
|
||||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
|
||||||
|
|
||||||
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
|
||||||
if err != nil || result == nil {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, ok := result.(string)
|
|
||||||
if !ok || value == "" {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
|
||||||
return uuid, accountID, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
|
|
||||||
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
if digestChain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
|
|
||||||
value := service.FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
|
|
||||||
|
|
||||||
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
|||||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Gemini Trie 会话测试 ============
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "testprefix"
|
|
||||||
digestChain := "u:hash1-m:hash2-u:hash3"
|
|
||||||
uuid := "test-uuid-123"
|
|
||||||
accountID := int64(42)
|
|
||||||
|
|
||||||
// 保存会话
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
|
||||||
require.NoError(s.T(), err, "SaveGeminiSession")
|
|
||||||
|
|
||||||
// 精确匹配查找
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
|
|
||||||
require.True(s.T(), found, "should find exact match")
|
|
||||||
require.Equal(s.T(), uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "prefixmatch"
|
|
||||||
shortChain := "u:a-m:b"
|
|
||||||
longChain := "u:a-m:b-u:c-m:d"
|
|
||||||
uuid := "uuid-prefix"
|
|
||||||
accountID := int64(100)
|
|
||||||
|
|
||||||
// 保存短链
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用长链查找,应该匹配到短链(前缀匹配)
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
|
|
||||||
require.True(s.T(), found, "should find prefix match")
|
|
||||||
require.Equal(s.T(), uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "longestmatch"
|
|
||||||
|
|
||||||
// 保存多个不同长度的链
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 查找更长的链,应该匹配到最长的前缀
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
|
|
||||||
require.True(s.T(), found, "should find longest prefix match")
|
|
||||||
require.Equal(s.T(), "uuid-long", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(3), foundAccountID)
|
|
||||||
|
|
||||||
// 查找中等长度的链
|
|
||||||
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
|
|
||||||
require.True(s.T(), found)
|
|
||||||
require.Equal(s.T(), "uuid-medium", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(2), foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "nomatch"
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存一个会话
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用不同的链查找,应该找不到
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
|
|
||||||
require.False(s.T(), found, "should not find non-matching chain")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
|
|
||||||
groupID := int64(1)
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存到 prefixHash1
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
|
|
||||||
require.False(s.T(), found, "different prefixHash should be isolated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
|
|
||||||
prefixHash := "sameprefix"
|
|
||||||
digestChain := "u:a-m:b"
|
|
||||||
|
|
||||||
// 保存到 groupID 1
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
|
|
||||||
// 用 groupID 2 查找,应该找不到(分组隔离)
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
|
|
||||||
require.False(s.T(), found, "different groupID should be isolated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "emptytest"
|
|
||||||
|
|
||||||
// 空链不应该保存
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
|
|
||||||
require.NoError(s.T(), err, "empty chain should not error")
|
|
||||||
|
|
||||||
// 空链查找应该返回 false
|
|
||||||
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
|
|
||||||
require.False(s.T(), found, "empty chain should not match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
|
|
||||||
groupID := int64(1)
|
|
||||||
prefixHash := "multisession"
|
|
||||||
|
|
||||||
// 保存多个不同会话(模拟 1000 个并发会话的场景)
|
|
||||||
sessions := []struct {
|
|
||||||
chain string
|
|
||||||
uuid string
|
|
||||||
accountID int64
|
|
||||||
}{
|
|
||||||
{"u:session1", "uuid-1", 1},
|
|
||||||
{"u:session2-m:reply2", "uuid-2", 2},
|
|
||||||
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sess := range sessions {
|
|
||||||
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
|
|
||||||
require.NoError(s.T(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证每个会话都能正确查找
|
|
||||||
for _, sess := range sessions {
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
|
|
||||||
require.True(s.T(), found, "should find session: %s", sess.chain)
|
|
||||||
require.Equal(s.T(), sess.uuid, foundUUID)
|
|
||||||
require.Equal(s.T(), sess.accountID, foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证继续对话的场景
|
|
||||||
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
|
|
||||||
require.True(s.T(), found)
|
|
||||||
require.Equal(s.T(), "uuid-2", foundUUID)
|
|
||||||
require.Equal(s.T(), int64(2), foundAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayCacheSuite(t *testing.T) {
|
func TestGatewayCacheSuite(t *testing.T) {
|
||||||
suite.Run(t, new(GatewayCacheSuite))
|
suite.Run(t, new(GatewayCacheSuite))
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -12,9 +11,6 @@ const (
|
|||||||
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
||||||
anthropicSessionTTLSeconds = 300
|
anthropicSessionTTLSeconds = 300
|
||||||
|
|
||||||
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
|
|
||||||
anthropicTrieKeyPrefix = "anthropic:trie:"
|
|
||||||
|
|
||||||
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
||||||
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
||||||
)
|
)
|
||||||
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
|
|
||||||
// 格式: anthropic:trie:{groupID}:{prefixHash}
|
|
||||||
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
|
|
||||||
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
||||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||||
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
||||||
|
|||||||
@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildAnthropicTrieKey(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
groupID int64
|
|
||||||
prefixHash string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal",
|
|
||||||
groupID: 123,
|
|
||||||
prefixHash: "abcdef12",
|
|
||||||
want: "anthropic:trie:123:abcdef12",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero group",
|
|
||||||
groupID: 0,
|
|
||||||
prefixHash: "xyz",
|
|
||||||
want: "anthropic:trie:0:xyz",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty prefix",
|
|
||||||
groupID: 1,
|
|
||||||
prefixHash: "",
|
|
||||||
want: "anthropic:trie:1:",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
69
backend/internal/service/digest_session_store.go
Normal file
69
backend/internal/service/digest_session_store.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// digestSessionTTL 摘要会话默认 TTL
|
||||||
|
const digestSessionTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// sessionEntry flat cache 条目
|
||||||
|
type sessionEntry struct {
|
||||||
|
uuid string
|
||||||
|
accountID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DigestSessionStore 内存摘要会话存储(flat cache 实现)
|
||||||
|
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
|
||||||
|
type DigestSessionStore struct {
|
||||||
|
cache *gocache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDigestSessionStore 创建内存摘要会话存储
|
||||||
|
func NewDigestSessionStore() *DigestSessionStore {
|
||||||
|
return &DigestSessionStore{
|
||||||
|
cache: gocache.New(digestSessionTTL, time.Minute),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||||
|
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
|
||||||
|
if digestChain == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ns := buildNS(groupID, prefixHash)
|
||||||
|
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
|
||||||
|
if oldDigestChain != "" && oldDigestChain != digestChain {
|
||||||
|
s.cache.Delete(ns + oldDigestChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
|
||||||
|
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
|
if digestChain == "" {
|
||||||
|
return "", 0, "", false
|
||||||
|
}
|
||||||
|
ns := buildNS(groupID, prefixHash)
|
||||||
|
chain := digestChain
|
||||||
|
for {
|
||||||
|
if val, ok := s.cache.Get(ns + chain); ok {
|
||||||
|
if e, ok := val.(*sessionEntry); ok {
|
||||||
|
return e.uuid, e.accountID, chain, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i := strings.LastIndex(chain, "-")
|
||||||
|
if i < 0 {
|
||||||
|
return "", 0, "", false
|
||||||
|
}
|
||||||
|
chain = chain[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildNS 构建 namespace 前缀
|
||||||
|
func buildNS(groupID int64, prefixHash string) string {
|
||||||
|
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
|
||||||
|
}
|
||||||
312
backend/internal/service/digest_session_store_test.go
Normal file
312
backend/internal/service/digest_session_store_test.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 保存短链
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
|
||||||
|
|
||||||
|
// 用长链查找,应前缀匹配到短链
|
||||||
|
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-short", uuid)
|
||||||
|
assert.Equal(t, int64(10), accountID)
|
||||||
|
assert.Equal(t, "u:a-m:b", matchedChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
|
||||||
|
|
||||||
|
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-3", uuid)
|
||||||
|
assert.Equal(t, int64(3), accountID)
|
||||||
|
|
||||||
|
// 查找中等长度,应匹配到 "u:a-m:b"
|
||||||
|
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(2), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 第一轮:保存 "u:a-m:b"
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
|
||||||
|
|
||||||
|
// 旧链 "u:a-m:b" 应已被删除
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found, "old chain should be deleted")
|
||||||
|
|
||||||
|
// 新链应能找到
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 相同系统提示词,不同用户提示词
|
||||||
|
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
|
||||||
|
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
|
||||||
|
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
|
||||||
|
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(200), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_NoMatch(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 完全不同的 chain
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 不同 prefixHash 应隔离
|
||||||
|
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 不同 groupID 应隔离
|
||||||
|
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 空链不应保存
|
||||||
|
store.Save(1, "prefix", "", "uuid-1", 100, "")
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
|
||||||
|
store := &DigestSessionStore{
|
||||||
|
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 立即应该能找到
|
||||||
|
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
require.True(t, found)
|
||||||
|
|
||||||
|
// 等待过期 + 清理周期
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
// 过期后应找不到
|
||||||
|
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
const goroutines = 50
|
||||||
|
const operations = 100
|
||||||
|
|
||||||
|
wg.Add(goroutines)
|
||||||
|
for g := 0; g < goroutines; g++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := fmt.Sprintf("prefix-%d", id%5)
|
||||||
|
for i := 0; i < operations; i++ {
|
||||||
|
chain := fmt.Sprintf("u:%d-m:%d", id, i)
|
||||||
|
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
|
||||||
|
store.Save(1, prefix, chain, uuid, int64(id), "")
|
||||||
|
store.Find(1, prefix, chain)
|
||||||
|
}
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
sessions := []struct {
|
||||||
|
chain string
|
||||||
|
uuid string
|
||||||
|
accountID int64
|
||||||
|
}{
|
||||||
|
{"u:session1", "uuid-1", 1},
|
||||||
|
{"u:session2-m:reply2", "uuid-2", 2},
|
||||||
|
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sess := range sessions {
|
||||||
|
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证每个会话都能正确查找
|
||||||
|
for _, sess := range sessions {
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
|
||||||
|
require.True(t, found, "should find session: %s", sess.chain)
|
||||||
|
assert.Equal(t, sess.uuid, uuid)
|
||||||
|
assert.Equal(t, sess.accountID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证继续对话的场景
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-2", uuid)
|
||||||
|
assert.Equal(t, int64(2), accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 插入 1000 个会话
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
|
||||||
|
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找性能测试
|
||||||
|
start := time.Now()
|
||||||
|
const lookups = 10000
|
||||||
|
for i := 0; i < lookups; i++ {
|
||||||
|
idx := i % 1000
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
|
||||||
|
_, _, _, found := store.Find(1, "prefix", chain)
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 精确匹配
|
||||||
|
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||||
|
|
||||||
|
// 前缀匹配(截断后命中)
|
||||||
|
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 模拟 100 个独立会话,每个进行 10 轮对话
|
||||||
|
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
|
||||||
|
for conv := 0; conv < 100; conv++ {
|
||||||
|
var prevMatchedChain string
|
||||||
|
for round := 0; round < 10; round++ {
|
||||||
|
chain := fmt.Sprintf("s:sys-u:user%d", conv)
|
||||||
|
for r := 0; r < round; r++ {
|
||||||
|
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
|
||||||
|
}
|
||||||
|
uuid := fmt.Sprintf("uuid-conv%d", conv)
|
||||||
|
|
||||||
|
_, _, matched, _ := store.Find(1, "prefix", chain)
|
||||||
|
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
|
||||||
|
prevMatchedChain = matched
|
||||||
|
_ = prevMatchedChain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
|
||||||
|
// 允许少量并发残留,但绝不能接近 100×10=1000
|
||||||
|
itemCount := store.cache.ItemCount()
|
||||||
|
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
|
||||||
|
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
|
||||||
|
// 使用极短 TTL 验证大量写入后 cache 能被清理
|
||||||
|
store := &DigestSessionStore{
|
||||||
|
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮)
|
||||||
|
for i := 0; i < 500; i++ {
|
||||||
|
chain := fmt.Sprintf("u:user%d", i)
|
||||||
|
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 500, store.cache.ItemCount())
|
||||||
|
|
||||||
|
// 等待 TTL + 清理周期
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
|
||||||
|
store := NewDigestSessionStore()
|
||||||
|
|
||||||
|
// 保存 chain
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
|
||||||
|
|
||||||
|
// 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key
|
||||||
|
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
|
||||||
|
|
||||||
|
// 仍然能找到
|
||||||
|
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
|
||||||
|
require.True(t, found)
|
||||||
|
assert.Equal(t, "uuid-1", uuid)
|
||||||
|
assert.Equal(t, int64(100), accountID)
|
||||||
|
}
|
||||||
@@ -224,22 +224,6 @@ func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, acc
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockGroupRepoForGateway struct {
|
type mockGroupRepoForGateway struct {
|
||||||
groups map[int64]*Group
|
groups map[int64]*Group
|
||||||
getByIDCalls int
|
getByIDCalls int
|
||||||
|
|||||||
@@ -305,23 +305,6 @@ type GatewayCache interface {
|
|||||||
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
||||||
// Batch get model load info for accounts (Antigravity only)
|
// Batch get model load info for accounts (Antigravity only)
|
||||||
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
|
|
||||||
// Find Gemini session using MGET reverse order matching
|
|
||||||
// 返回最长匹配的会话信息(uuid, accountID)
|
|
||||||
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话
|
|
||||||
// Save Gemini session binding
|
|
||||||
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(Trie 匹配)
|
|
||||||
// Find Anthropic session using Trie matching
|
|
||||||
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话
|
|
||||||
// Save Anthropic session binding
|
|
||||||
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||||
@@ -416,6 +399,7 @@ type GatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
userGroupRateRepo UserGroupRateRepository
|
userGroupRateRepo UserGroupRateRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
|
digestStore *DigestSessionStore
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
schedulerSnapshot *SchedulerSnapshotService
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
@@ -449,6 +433,7 @@ func NewGatewayService(
|
|||||||
deferredService *DeferredService,
|
deferredService *DeferredService,
|
||||||
claudeTokenProvider *ClaudeTokenProvider,
|
claudeTokenProvider *ClaudeTokenProvider,
|
||||||
sessionLimitCache SessionLimitCache,
|
sessionLimitCache SessionLimitCache,
|
||||||
|
digestStore *DigestSessionStore,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
return &GatewayService{
|
return &GatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
@@ -458,6 +443,7 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
userGroupRateRepo: userGroupRateRepo,
|
userGroupRateRepo: userGroupRateRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
digestStore: digestStore,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
schedulerSnapshot: schedulerSnapshot,
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
@@ -557,35 +543,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
|
|||||||
|
|
||||||
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
||||||
// 返回最长匹配的会话信息(uuid, accountID)
|
// 返回最长匹配的会话信息(uuid, accountID)
|
||||||
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return "", 0, false
|
return "", 0, "", false
|
||||||
}
|
}
|
||||||
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
|
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGeminiSession 保存 Gemini 会话
|
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||||
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
||||||
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return "", 0, false
|
return "", 0, "", false
|
||||||
}
|
}
|
||||||
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
|
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveAnthropicSession 保存 Anthropic 会话
|
// SaveAnthropicSession 保存 Anthropic 会话
|
||||||
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||||
if digestChain == "" || s.cache == nil {
|
if digestChain == "" || s.digestStore == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||||
|
|||||||
@@ -277,22 +277,6 @@ func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accou
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -6,26 +6,11 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gemini 会话 ID Fallback 相关常量
|
|
||||||
const (
|
|
||||||
// geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟)
|
|
||||||
geminiSessionTTLSeconds = 300
|
|
||||||
|
|
||||||
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
|
|
||||||
geminiSessionKeyPrefix = "gemini:sess:"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
|
|
||||||
func GeminiSessionTTL() time.Duration {
|
|
||||||
return geminiSessionTTLSeconds * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
|
// shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符)
|
||||||
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
|
// XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20%
|
||||||
func shortHash(data []byte) string {
|
func shortHash(data []byte) string {
|
||||||
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
|
|||||||
return base64.RawURLEncoding.EncodeToString(hash[:12])
|
return base64.RawURLEncoding.EncodeToString(hash[:12])
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
|
|
||||||
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
|
|
||||||
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
|
|
||||||
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
|
|
||||||
// 用于 MGET 批量查询最长匹配
|
|
||||||
func GenerateDigestChainPrefixes(chain string) []string {
|
|
||||||
if chain == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var prefixes []string
|
|
||||||
c := chain
|
|
||||||
|
|
||||||
for c != "" {
|
|
||||||
prefixes = append(prefixes, c)
|
|
||||||
// 找到最后一个 "-" 的位置
|
|
||||||
if i := strings.LastIndex(c, "-"); i > 0 {
|
|
||||||
c = c[:i]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return prefixes
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
|
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
|
||||||
// 格式: {uuid}:{accountID}
|
// 格式: {uuid}:{accountID}
|
||||||
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
|
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
|
||||||
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
|
|||||||
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
|
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
|
||||||
const geminiDigestSessionKeyPrefix = "gemini:digest:"
|
const geminiDigestSessionKeyPrefix = "gemini:digest:"
|
||||||
|
|
||||||
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
|
|
||||||
const geminiTrieKeyPrefix = "gemini:trie:"
|
|
||||||
|
|
||||||
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
|
|
||||||
// 格式: gemini:trie:{groupID}:{prefixHash}
|
|
||||||
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
|
|
||||||
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
|
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
|
||||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||||
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
|
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话
|
||||||
|
|||||||
@@ -1,41 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockGeminiSessionCache 模拟 Redis 缓存
|
|
||||||
type mockGeminiSessionCache struct {
|
|
||||||
sessions map[string]string // key -> value
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockGeminiSessionCache() *mockGeminiSessionCache {
|
|
||||||
return &mockGeminiSessionCache{sessions: make(map[string]string)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
|
|
||||||
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
|
|
||||||
value := FormatGeminiSessionValue(uuid, accountID)
|
|
||||||
m.sessions[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
prefixes := GenerateDigestChainPrefixes(digestChain)
|
|
||||||
for _, p := range prefixes {
|
|
||||||
key := BuildGeminiSessionKey(groupID, prefixHash, p)
|
|
||||||
if val, ok := m.sessions[key]; ok {
|
|
||||||
return ParseGeminiSessionValue(val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
||||||
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
sessionUUID := "session-uuid-12345"
|
sessionUUID := "session-uuid-12345"
|
||||||
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 1 chain: %s", chain1)
|
t.Logf("Round 1 chain: %s", chain1)
|
||||||
|
|
||||||
// 第一轮:没有找到会话,创建新会话
|
// 第一轮:没有找到会话,创建新会话
|
||||||
_, _, found := cache.Find(groupID, prefixHash, chain1)
|
_, _, _, found := store.Find(groupID, prefixHash, chain1)
|
||||||
if found {
|
if found {
|
||||||
t.Error("Round 1: should not find existing session")
|
t.Error("Round 1: should not find existing session")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存第一轮会话
|
// 保存第一轮会话(首轮无旧 chain)
|
||||||
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
|
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
|
||||||
|
|
||||||
// 模拟第二轮对话(用户继续对话)
|
// 模拟第二轮对话(用户继续对话)
|
||||||
req2 := &antigravity.GeminiRequest{
|
req2 := &antigravity.GeminiRequest{
|
||||||
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 2 chain: %s", chain2)
|
t.Logf("Round 2 chain: %s", chain2)
|
||||||
|
|
||||||
// 第二轮:应该能找到会话(通过前缀匹配)
|
// 第二轮:应该能找到会话(通过前缀匹配)
|
||||||
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
|
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Round 2: should find session via prefix matching")
|
t.Error("Round 2: should find session via prefix matching")
|
||||||
}
|
}
|
||||||
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存第二轮会话
|
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
|
||||||
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
|
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
|
||||||
|
|
||||||
// 模拟第三轮对话
|
// 模拟第三轮对话
|
||||||
req3 := &antigravity.GeminiRequest{
|
req3 := &antigravity.GeminiRequest{
|
||||||
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
t.Logf("Round 3 chain: %s", chain3)
|
t.Logf("Round 3 chain: %s", chain3)
|
||||||
|
|
||||||
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
||||||
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
|
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Round 3: should find session via prefix matching")
|
t.Error("Round 3: should find session via prefix matching")
|
||||||
}
|
}
|
||||||
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
|||||||
if foundAccID != accountID {
|
if foundAccID != accountID {
|
||||||
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Continuous conversation session matching works correctly!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
||||||
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
|
|
||||||
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
chain1 := BuildGeminiDigestChain(req1)
|
chain1 := BuildGeminiDigestChain(req1)
|
||||||
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
|
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
|
||||||
|
|
||||||
// 第二个完全不同的会话
|
// 第二个完全不同的会话
|
||||||
req2 := &antigravity.GeminiRequest{
|
req2 := &antigravity.GeminiRequest{
|
||||||
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
|||||||
chain2 := BuildGeminiDigestChain(req2)
|
chain2 := BuildGeminiDigestChain(req2)
|
||||||
|
|
||||||
// 不同会话不应该匹配
|
// 不同会话不应该匹配
|
||||||
_, _, found := cache.Find(groupID, prefixHash, chain2)
|
_, _, _, found := store.Find(groupID, prefixHash, chain2)
|
||||||
if found {
|
if found {
|
||||||
t.Error("Different conversations should not match")
|
t.Error("Different conversations should not match")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Different conversations are correctly isolated!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
||||||
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
||||||
cache := newMockGeminiSessionCache()
|
store := NewDigestSessionStore()
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
prefixHash := "test_prefix_hash"
|
prefixHash := "test_prefix_hash"
|
||||||
|
|
||||||
// 创建一个三轮对话
|
|
||||||
req := &antigravity.GeminiRequest{
|
|
||||||
SystemInstruction: &antigravity.GeminiContent{
|
|
||||||
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
|
|
||||||
},
|
|
||||||
Contents: []antigravity.GeminiContent{
|
|
||||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
|
|
||||||
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
|
|
||||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
fullChain := BuildGeminiDigestChain(req)
|
|
||||||
prefixes := GenerateDigestChainPrefixes(fullChain)
|
|
||||||
|
|
||||||
t.Logf("Full chain: %s", fullChain)
|
|
||||||
t.Logf("Prefixes (longest first): %v", prefixes)
|
|
||||||
|
|
||||||
// 验证前缀生成顺序(从长到短)
|
|
||||||
if len(prefixes) != 4 {
|
|
||||||
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存不同轮次的会话到不同账号
|
// 保存不同轮次的会话到不同账号
|
||||||
// 第一轮(最短前缀)-> 账号 1
|
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
|
||||||
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
|
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
|
||||||
// 第二轮 -> 账号 2
|
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
|
||||||
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
|
|
||||||
// 第三轮(最长前缀,完整链)-> 账号 3
|
|
||||||
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
|
|
||||||
|
|
||||||
// 查找应该返回最长匹配(账号 3)
|
// 查找更长的链,应该返回最长匹配(账号 3)
|
||||||
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
|
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
|
||||||
if !found {
|
if !found {
|
||||||
t.Error("Should find session")
|
t.Error("Should find session")
|
||||||
}
|
}
|
||||||
if accID != 3 {
|
if accID != 3 {
|
||||||
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("✓ Longest prefix matching works correctly!")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确保 context 包被使用(避免未使用的导入警告)
|
|
||||||
var _ = context.Background
|
|
||||||
|
|||||||
@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateDigestChainPrefixes(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
chain string
|
|
||||||
want []string
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
chain: "",
|
|
||||||
wantLen: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single part",
|
|
||||||
chain: "u:abc123",
|
|
||||||
want: []string{"u:abc123"},
|
|
||||||
wantLen: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two parts",
|
|
||||||
chain: "s:xyz-u:abc",
|
|
||||||
want: []string{"s:xyz-u:abc", "s:xyz"},
|
|
||||||
wantLen: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "four parts",
|
|
||||||
chain: "s:a-u:b-m:c-u:d",
|
|
||||||
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
|
|
||||||
wantLen: 4,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := GenerateDigestChainPrefixes(tt.chain)
|
|
||||||
|
|
||||||
if len(result) != tt.wantLen {
|
|
||||||
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.want != nil {
|
|
||||||
for i, want := range tt.want {
|
|
||||||
if i >= len(result) {
|
|
||||||
t.Errorf("missing prefix at index %d", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if result[i] != want {
|
|
||||||
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseGeminiSessionValue(t *testing.T) {
|
func TestParseGeminiSessionValue(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildGeminiTrieKey(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
groupID int64
|
|
||||||
prefixHash string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal",
|
|
||||||
groupID: 123,
|
|
||||||
prefixHash: "abcdef12",
|
|
||||||
want: "gemini:trie:123:abcdef12",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero group",
|
|
||||||
groupID: 0,
|
|
||||||
prefixHash: "xyz",
|
|
||||||
want: "gemini:trie:0:xyz",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty prefix",
|
|
||||||
groupID: 1,
|
|
||||||
prefixHash: "",
|
|
||||||
want: "gemini:trie:1:",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -212,22 +212,6 @@ func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []i
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
resetAt := now.Add(10 * time.Minute)
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
|||||||
@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewUsageCache,
|
NewUsageCache,
|
||||||
NewTotpService,
|
NewTotpService,
|
||||||
NewErrorPassthroughService,
|
NewErrorPassthroughService,
|
||||||
|
NewDigestSessionStore,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user