From b889d5017b9e61d7dec4a856d6a85a26b3c08929 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 9 Feb 2026 07:02:12 +0800 Subject: [PATCH] refactor: replace Trie-based digest session store with flat cache --- backend/cmd/server/wire_gen.go | 3 +- backend/go.mod | 1 + backend/go.sum | 2 + .../internal/handler/gemini_v1beta_handler.go | 5 +- backend/internal/repository/gateway_cache.go | 137 -------- .../gateway_cache_integration_test.go | 151 --------- backend/internal/service/anthropic_session.go | 10 - .../service/anthropic_session_test.go | 37 --- .../internal/service/digest_session_store.go | 69 ++++ .../service/digest_session_store_test.go | 312 ++++++++++++++++++ .../service/gateway_multiplatform_test.go | 16 - backend/internal/service/gateway_service.go | 52 ++- .../service/gemini_multiplatform_test.go | 16 - backend/internal/service/gemini_session.go | 53 --- .../gemini_session_integration_test.go | 95 +----- .../internal/service/gemini_session_test.go | 92 ------ .../service/openai_gateway_service_test.go | 16 - backend/internal/service/wire.go | 1 + 18 files changed, 428 insertions(+), 640 deletions(-) create mode 100644 backend/internal/service/digest_session_store.go create mode 100644 backend/internal/service/digest_session_store_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ef205dc8..5ccd797e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + digestSessionStore := service.NewDigestSessionStore() + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/go.mod b/backend/go.mod index 6916057f..08d54b91 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -103,6 +103,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/backend/go.sum b/backend/go.sum index 90470fbc..71e8f504 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -213,6 +213,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d29749c7..d5149f22 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -259,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var geminiDigestChain string var geminiPrefixHash string var geminiSessionUUID string + var matchedDigestChain string useDigestFallback := sessionBoundAccountID == 0 if useDigestFallback { @@ -285,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ) // 查找会话 - foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession( + foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession( c.Request.Context(), derefGroupID(apiKey.GroupID), geminiPrefixHash, geminiDigestChain, ) if found { + matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", @@ -458,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { geminiDigestChain, geminiSessionUUID, account.ID, + matchedDigestChain, ); err != nil { log.Printf("[Gemini] Failed to save digest session: %v", err) } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index b9cc521e..2c4f3b8e 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -11,64 +11,6 @@ import ( const stickySessionPrefix = "sticky_session:" -// Gemini Trie Lua 脚本 -const ( - // geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d") - // ARGV[2] = TTL seconds (用于刷新) - // 返回: 最长匹配的 value (uuid:accountID) 或 nil - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - // 从最长前缀(完整 chain)开始逐步缩短,第一次命中即返回 - geminiTrieFindScript = ` -local chain = ARGV[1] -local ttl = tonumber(ARGV[2]) - --- 先尝试完整 chain(最常见场景:同一对话的下一轮请求) -local val = redis.call('HGET', KEYS[1], chain) -if val and val ~= "" then - redis.call('EXPIRE', KEYS[1], ttl) - return val -end - --- 从最长前缀开始逐步缩短(去掉最后一个 "-xxx" 段) -local path = chain -while true do - local i = string.find(path, "-[^-]*$") - if not i or i <= 1 then - break - end - path = string.sub(path, 1, i - 1) - val = redis.call('HGET', KEYS[1], path) - if val and val ~= "" then - redis.call('EXPIRE', KEYS[1], ttl) - return val - end -end - -return nil -` - - // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain - // ARGV[2] = value (uuid:accountID) - // ARGV[3] = TTL seconds - geminiTrieSaveScript = ` -local chain = ARGV[1] -local value = ARGV[2] -local ttl = tonumber(ARGV[3]) -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part -end -redis.call('HSET', KEYS[1], path, value) -redis.call('EXPIRE', KEYS[1], ttl) -return "OK" -` -) - // 模型负载统计相关常量 const ( modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀 @@ -206,82 +148,3 @@ func getTimeOrZero(cmd *redis.StringCmd) time.Time { } return time.Unix(val, 0) } - -// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============ - -// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询) -// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL -func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - // 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返 - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本) -func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} - -// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============ - -// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本) -func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) - ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) - - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本) -func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index fc8e7372..2fdaa3d1 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } -// ============ Gemini Trie 会话测试 ============ - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() { - groupID := int64(1) - prefixHash := "testprefix" - digestChain := "u:hash1-m:hash2-u:hash3" - uuid := "test-uuid-123" - accountID := int64(42) - - // 保存会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID) - require.NoError(s.T(), err, "SaveGeminiSession") - - // 精确匹配查找 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain) - require.True(s.T(), found, "should find exact match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() { - groupID := int64(1) - prefixHash := "prefixmatch" - shortChain := "u:a-m:b" - longChain := "u:a-m:b-u:c-m:d" - uuid := "uuid-prefix" - accountID := int64(100) - - // 保存短链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID) - require.NoError(s.T(), err) - - // 用长链查找,应该匹配到短链(前缀匹配) - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain) - require.True(s.T(), found, "should find prefix match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() { - groupID := int64(1) - prefixHash := "longestmatch" - - // 保存多个不同长度的链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3) - require.NoError(s.T(), err) - - // 查找更长的链,应该匹配到最长的前缀 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e") - require.True(s.T(), found, "should find longest prefix match") - require.Equal(s.T(), "uuid-long", foundUUID) - require.Equal(s.T(), int64(3), foundAccountID) - - // 查找中等长度的链 - foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-medium", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() { - groupID := int64(1) - prefixHash := "nomatch" - digestChain := "u:a-m:b" - - // 保存一个会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1) - require.NoError(s.T(), err) - - // 用不同的链查找,应该找不到 - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y") - require.False(s.T(), found, "should not find non-matching chain") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() { - groupID := int64(1) - digestChain := "u:a-m:b" - - // 保存到 prefixHash1 - err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain) - require.False(s.T(), found, "different prefixHash should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() { - prefixHash := "sameprefix" - digestChain := "u:a-m:b" - - // 保存到 groupID 1 - err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 groupID 2 查找,应该找不到(分组隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain) - require.False(s.T(), found, "different groupID should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() { - groupID := int64(1) - prefixHash := "emptytest" - - // 空链不应该保存 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1) - require.NoError(s.T(), err, "empty chain should not error") - - // 空链查找应该返回 false - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "") - require.False(s.T(), found, "empty chain should not match") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() { - groupID := int64(1) - prefixHash := "multisession" - - // 保存多个不同会话(模拟 1000 个并发会话的场景) - sessions := []struct { - chain string - uuid string - accountID int64 - }{ - {"u:session1", "uuid-1", 1}, - {"u:session2-m:reply2", "uuid-2", 2}, - {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, - } - - for _, sess := range sessions { - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID) - require.NoError(s.T(), err) - } - - // 验证每个会话都能正确查找 - for _, sess := range sessions { - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain) - require.True(s.T(), found, "should find session: %s", sess.chain) - require.Equal(s.T(), sess.uuid, foundUUID) - require.Equal(s.T(), sess.accountID, foundAccountID) - } - - // 验证继续对话的场景 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-2", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) diff --git a/backend/internal/service/anthropic_session.go b/backend/internal/service/anthropic_session.go index 2d86ed35..26544c68 100644 --- a/backend/internal/service/anthropic_session.go +++ b/backend/internal/service/anthropic_session.go @@ -2,7 +2,6 @@ package service import ( "encoding/json" - "strconv" "strings" "time" ) @@ -12,9 +11,6 @@ const ( // anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟) anthropicSessionTTLSeconds = 300 - // anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀 - anthropicTrieKeyPrefix = "anthropic:trie:" - // anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀 anthropicDigestSessionKeyPrefix = "anthropic:digest:" ) @@ -68,12 +64,6 @@ func rolePrefix(role string) string { } } -// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key -// 格式: anthropic:trie:{groupID}:{prefixHash} -func BuildAnthropicTrieKey(groupID int64, prefixHash string) string { - return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string { diff --git a/backend/internal/service/anthropic_session_test.go b/backend/internal/service/anthropic_session_test.go index e2f873e7..10406643 100644 --- a/backend/internal/service/anthropic_session_test.go +++ b/backend/internal/service/anthropic_session_test.go @@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) { } } -func TestBuildAnthropicTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "anthropic:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "anthropic:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "anthropic:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - } - }) - } -} - func TestGenerateAnthropicDigestSessionKey(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/digest_session_store.go b/backend/internal/service/digest_session_store.go new file mode 100644 index 00000000..3ac08936 --- /dev/null +++ b/backend/internal/service/digest_session_store.go @@ -0,0 +1,69 @@ +package service + +import ( + "strconv" + "strings" + "time" + + gocache "github.com/patrickmn/go-cache" +) + +// digestSessionTTL 摘要会话默认 TTL +const digestSessionTTL = 5 * time.Minute + +// sessionEntry flat cache 条目 +type sessionEntry struct { + uuid string + accountID int64 +} + +// DigestSessionStore 内存摘要会话存储(flat cache 实现) +// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry +type DigestSessionStore struct { + cache *gocache.Cache +} + +// NewDigestSessionStore 创建内存摘要会话存储 +func NewDigestSessionStore() *DigestSessionStore { + return &DigestSessionStore{ + cache: gocache.New(digestSessionTTL, time.Minute), + } +} + +// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) { + if digestChain == "" { + return + } + ns := buildNS(groupID, prefixHash) + s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration) + if oldDigestChain != "" && oldDigestChain != digestChain { + s.cache.Delete(ns + oldDigestChain) + } +} + +// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。 +func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" { + return "", 0, "", false + } + ns := buildNS(groupID, prefixHash) + chain := digestChain + for { + if val, ok := s.cache.Get(ns + chain); ok { + if e, ok := val.(*sessionEntry); ok { + return e.uuid, e.accountID, chain, true + } + } + i := strings.LastIndex(chain, "-") + if i < 0 { + return "", 0, "", false + } + chain = chain[:i] + } +} + +// buildNS 构建 namespace 前缀 +func buildNS(groupID int64, prefixHash string) string { + return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|" +} diff --git a/backend/internal/service/digest_session_store_test.go b/backend/internal/service/digest_session_store_test.go new file mode 100644 index 00000000..e505bf30 --- /dev/null +++ b/backend/internal/service/digest_session_store_test.go @@ -0,0 +1,312 @@ +//go:build unit + +package service + +import ( + "fmt" + "sync" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDigestSessionStore_SaveAndFind(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_PrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + // 保存短链 + store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "") + + // 用长链查找,应前缀匹配到短链 + uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-short", uuid) + assert.Equal(t, int64(10), accountID) + assert.Equal(t, "u:a-m:b", matchedChain) +} + +func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a", "uuid-1", 1, "") + store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "") + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "") + + // 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的) + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "uuid-3", uuid) + assert.Equal(t, int64(3), accountID) + + // 查找中等长度,应匹配到 "u:a-m:b" + uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) { + store := NewDigestSessionStore() + + // 第一轮:保存 "u:a-m:b" + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 第二轮:同一 uuid 保存更长的链,传入旧 chain + store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b") + + // 旧链 "u:a-m:b" 应已被删除 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found, "old chain should be deleted") + + // 新链应能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) { + store := NewDigestSessionStore() + + // 相同系统提示词,不同用户提示词 + store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "") + store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) + + uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(200), accountID) +} + +func TestDigestSessionStore_NoMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 完全不同的 chain + _, _, _, found := store.Find(1, "prefix", "u:x-m:y") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "") + + // 不同 prefixHash 应隔离 + _, _, _, found := store.Find(1, "prefix2", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentGroupID(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 不同 groupID 应隔离 + _, _, _, found := store.Find(2, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_EmptyDigestChain(t *testing.T) { + store := NewDigestSessionStore() + + // 空链不应保存 + store.Save(1, "prefix", "", "uuid-1", 100, "") + _, _, _, found := store.Find(1, "prefix", "") + assert.False(t, found) +} + +func TestDigestSessionStore_TTLExpiration(t *testing.T) { + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 立即应该能找到 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + + // 等待过期 + 清理周期 + time.Sleep(300 * time.Millisecond) + + // 过期后应找不到 + _, _, _, found = store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_ConcurrentSafety(t *testing.T) { + store := NewDigestSessionStore() + + var wg sync.WaitGroup + const goroutines = 50 + const operations = 100 + + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + prefix := fmt.Sprintf("prefix-%d", id%5) + for i := 0; i < operations; i++ { + chain := fmt.Sprintf("u:%d-m:%d", id, i) + uuid := fmt.Sprintf("uuid-%d-%d", id, i) + store.Save(1, prefix, chain, uuid, int64(id), "") + store.Find(1, prefix, chain) + } + }(g) + } + wg.Wait() +} + +func TestDigestSessionStore_MultipleSessions(t *testing.T) { + store := NewDigestSessionStore() + + sessions := []struct { + chain string + uuid string + accountID int64 + }{ + {"u:session1", "uuid-1", 1}, + {"u:session2-m:reply2", "uuid-2", 2}, + {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, + } + + for _, sess := range sessions { + store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "") + } + + // 验证每个会话都能正确查找 + for _, sess := range sessions { + uuid, accountID, _, found := store.Find(1, "prefix", sess.chain) + require.True(t, found, "should find session: %s", sess.chain) + assert.Equal(t, sess.uuid, uuid) + assert.Equal(t, sess.accountID, accountID) + } + + // 验证继续对话的场景 + uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_Performance1000Sessions(t *testing.T) { + store := NewDigestSessionStore() + + // 插入 1000 个会话 + for i := 0; i < 1000; i++ { + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + // 查找性能测试 + start := time.Now() + const lookups = 10000 + for i := 0; i < lookups; i++ { + idx := i % 1000 + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx) + _, _, _, found := store.Find(1, "prefix", chain) + assert.True(t, found) + } + elapsed := time.Since(start) + t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups) +} + +func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "") + + // 精确匹配 + _, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) + + // 前缀匹配(截断后命中) + _, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) +} + +func TestDigestSessionStore_CacheItemCountStable(t *testing.T) { + store := NewDigestSessionStore() + + // 模拟 100 个独立会话,每个进行 10 轮对话 + // 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key + for conv := 0; conv < 100; conv++ { + var prevMatchedChain string + for round := 0; round < 10; round++ { + chain := fmt.Sprintf("s:sys-u:user%d", conv) + for r := 0; r < round; r++ { + chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1) + } + uuid := fmt.Sprintf("uuid-conv%d", conv) + + _, _, matched, _ := store.Find(1, "prefix", chain) + store.Save(1, "prefix", chain, uuid, int64(conv), matched) + prevMatchedChain = matched + _ = prevMatchedChain + } + } + + // 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key + // 允许少量并发残留,但绝不能接近 100×10=1000 + itemCount := store.cache.ItemCount() + assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount) + t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount) +} + +func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) { + // 使用极短 TTL 验证大量写入后 cache 能被清理 + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + // 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮) + for i := 0; i < 500; i++ { + chain := fmt.Sprintf("u:user%d", i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + assert.Equal(t, 500, store.cache.ItemCount()) + + // 等待 TTL + 清理周期 + time.Sleep(300 * time.Millisecond) + + assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up") +} + +func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) { + store := NewDigestSessionStore() + + // 保存 chain + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b") + + // 仍然能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index bb0c97e8..069ea7d7 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -224,22 +224,6 @@ func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, acc return nil, nil } -func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6572f25d..af8838dc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -305,23 +305,6 @@ type GatewayCache interface { // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) // Batch get model load info for accounts (Antigravity only) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) - - // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) - // Find Gemini session using MGET reverse order matching - // 返回最长匹配的会话信息(uuid, accountID) - FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveGeminiSession 保存 Gemini 会话 - // Save Gemini session binding - SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error - - // FindAnthropicSession 查找 Anthropic 会话(Trie 匹配) - // Find Anthropic session using Trie matching - FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveAnthropicSession 保存 Anthropic 会话 - // Save Anthropic session binding - SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -416,6 +399,7 @@ type GatewayService struct { userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository cache GatewayCache + digestStore *DigestSessionStore cfg *config.Config schedulerSnapshot *SchedulerSnapshotService billingService *BillingService @@ -449,6 +433,7 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + digestStore *DigestSessionStore, ) *GatewayService { return &GatewayService{ accountRepo: accountRepo, @@ -458,6 +443,7 @@ func NewGatewayService( userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, cache: cache, + digestStore: digestStore, cfg: cfg, schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, @@ -557,35 +543,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // 返回最长匹配的会话信息(uuid, accountID) -func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } -// SaveGeminiSession 保存 Gemini 会话 -func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } // FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) -func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } // SaveAnthropicSession 保存 Anthropic 会话 -func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 45686d80..c738b79f 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -277,22 +277,6 @@ func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accou return nil, nil } -func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 859ae9f3..1780d1da 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -6,26 +6,11 @@ import ( "encoding/json" "strconv" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/cespare/xxhash/v2" ) -// Gemini 会话 ID Fallback 相关常量 -const ( - // geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟) - geminiSessionTTLSeconds = 300 - - // geminiSessionKeyPrefix Gemini 会话 Redis key 前缀 - geminiSessionKeyPrefix = "gemini:sess:" -) - -// GeminiSessionTTL 返回 Gemini 会话缓存 TTL -func GeminiSessionTTL() time.Duration { - return geminiSessionTTLSeconds * time.Second -} - // shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) // XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% func shortHash(data []byte) string { @@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m return base64.RawURLEncoding.EncodeToString(hash[:12]) } -// BuildGeminiSessionKey 构建 Gemini 会话 Redis key -// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain} -func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string { - return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain -} - -// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短) -// 用于 MGET 批量查询最长匹配 -func GenerateDigestChainPrefixes(chain string) []string { - if chain == "" { - return nil - } - - var prefixes []string - c := chain - - for c != "" { - prefixes = append(prefixes, c) - // 找到最后一个 "-" 的位置 - if i := strings.LastIndex(c, "-"); i > 0 { - c = c[:i] - } else { - break - } - } - - return prefixes -} - // ParseGeminiSessionValue 解析 Gemini 会话缓存值 // 格式: {uuid}:{accountID} func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { @@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string { // geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 const geminiDigestSessionKeyPrefix = "gemini:digest:" -// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀 -const geminiTrieKeyPrefix = "gemini:trie:" - -// BuildGeminiTrieKey 构建 Gemini Trie Redis key -// 格式: gemini:trie:{groupID}:{prefixHash} -func BuildGeminiTrieKey(groupID int64, prefixHash string) string { - return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey // 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 diff --git a/backend/internal/service/gemini_session_integration_test.go b/backend/internal/service/gemini_session_integration_test.go index 928c62cf..95b5f594 100644 --- a/backend/internal/service/gemini_session_integration_test.go +++ b/backend/internal/service/gemini_session_integration_test.go @@ -1,41 +1,14 @@ package service import ( - "context" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// mockGeminiSessionCache 模拟 Redis 缓存 -type mockGeminiSessionCache struct { - sessions map[string]string // key -> value -} - -func newMockGeminiSessionCache() *mockGeminiSessionCache { - return &mockGeminiSessionCache{sessions: make(map[string]string)} -} - -func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) { - key := BuildGeminiSessionKey(groupID, prefixHash, digestChain) - value := FormatGeminiSessionValue(uuid, accountID) - m.sessions[key] = value -} - -func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - prefixes := GenerateDigestChainPrefixes(digestChain) - for _, p := range prefixes { - key := BuildGeminiSessionKey(groupID, prefixHash, p) - if val, ok := m.sessions[key]; ok { - return ParseGeminiSessionValue(val) - } - } - return "", 0, false -} - // TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 func TestGeminiSessionContinuousConversation(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" sessionUUID := "session-uuid-12345" @@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 1 chain: %s", chain1) // 第一轮:没有找到会话,创建新会话 - _, _, found := cache.Find(groupID, prefixHash, chain1) + _, _, _, found := store.Find(groupID, prefixHash, chain1) if found { t.Error("Round 1: should not find existing session") } - // 保存第一轮会话 - cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) + // 保存第一轮会话(首轮无旧 chain) + store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "") // 模拟第二轮对话(用户继续对话) req2 := &antigravity.GeminiRequest{ @@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 2 chain: %s", chain2) // 第二轮:应该能找到会话(通过前缀匹配) - foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) + foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2) if !found { t.Error("Round 2: should find session via prefix matching") } @@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) } - // 保存第二轮会话 - cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) + // 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key + store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain) // 模拟第三轮对话 req3 := &antigravity.GeminiRequest{ @@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 3 chain: %s", chain3) // 第三轮:应该能找到会话(通过第二轮的前缀匹配) - foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) + foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3) if !found { t.Error("Round 3: should find session via prefix matching") } @@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { if foundAccID != accountID { t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) } - - t.Log("✓ Continuous conversation session matching works correctly!") } // TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 func TestGeminiSessionDifferentConversations(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" @@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { }, } chain1 := BuildGeminiDigestChain(req1) - cache.Save(groupID, prefixHash, chain1, "session-1", 100) + store.Save(groupID, prefixHash, chain1, "session-1", 100, "") // 第二个完全不同的会话 req2 := &antigravity.GeminiRequest{ @@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { chain2 := BuildGeminiDigestChain(req2) // 不同会话不应该匹配 - _, _, found := cache.Find(groupID, prefixHash, chain2) + _, _, _, found := store.Find(groupID, prefixHash, chain2) if found { t.Error("Different conversations should not match") } - - t.Log("✓ Different conversations are correctly isolated!") } // TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" - // 创建一个三轮对话 - req := &antigravity.GeminiRequest{ - SystemInstruction: &antigravity.GeminiContent{ - Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, - }, - Contents: []antigravity.GeminiContent{ - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}}, - {Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}}, - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}}, - }, - } - fullChain := BuildGeminiDigestChain(req) - prefixes := GenerateDigestChainPrefixes(fullChain) - - t.Logf("Full chain: %s", fullChain) - t.Logf("Prefixes (longest first): %v", prefixes) - - // 验证前缀生成顺序(从长到短) - if len(prefixes) != 4 { - t.Errorf("Expected 4 prefixes, got %d", len(prefixes)) - } - // 保存不同轮次的会话到不同账号 - // 第一轮(最短前缀)-> 账号 1 - cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) - // 第二轮 -> 账号 2 - cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2) - // 第三轮(最长前缀,完整链)-> 账号 3 - cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3) + store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "") - // 查找应该返回最长匹配(账号 3) - _, accID, found := cache.Find(groupID, prefixHash, fullChain) + // 查找更长的链,应该返回最长匹配(账号 3) + _, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2") if !found { t.Error("Should find session") } if accID != 3 { t.Errorf("Should match longest prefix (account 3), got account %d", accID) } - - t.Log("✓ Longest prefix matching works correctly!") } - -// 确保 context 包被使用(避免未使用的导入警告) -var _ = context.Background diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index 8c1908f7..a034cddd 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } -func TestGenerateDigestChainPrefixes(t *testing.T) { - tests := []struct { - name string - chain string - want []string - wantLen int - }{ - { - name: "empty", - chain: "", - wantLen: 0, - }, - { - name: "single part", - chain: "u:abc123", - want: []string{"u:abc123"}, - wantLen: 1, - }, - { - name: "two parts", - chain: "s:xyz-u:abc", - want: []string{"s:xyz-u:abc", "s:xyz"}, - wantLen: 2, - }, - { - name: "four parts", - chain: "s:a-u:b-m:c-u:d", - want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"}, - wantLen: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GenerateDigestChainPrefixes(tt.chain) - - if len(result) != tt.wantLen { - t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result) - } - - if tt.want != nil { - for i, want := range tt.want { - if i >= len(result) { - t.Errorf("missing prefix at index %d", i) - continue - } - if result[i] != want { - t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i]) - } - } - } - }) - } -} - func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string @@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) { } }) } - -func TestBuildGeminiTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "gemini:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "gemini:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "gemini:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - } - }) - } -} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 159b0afb..22b4730d 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -212,22 +212,6 @@ func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []i return nil, nil } -func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - -func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 05371022..87ca7897 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet( NewUsageCache, NewTotpService, NewErrorPassthroughService, + NewDigestSessionStore, )