refactor: replace Trie-based digest session store with flat cache
This commit is contained in:
@@ -1,41 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// mockGeminiSessionCache 模拟 Redis 缓存
|
||||
type mockGeminiSessionCache struct {
|
||||
sessions map[string]string // key -> value
|
||||
}
|
||||
|
||||
func newMockGeminiSessionCache() *mockGeminiSessionCache {
|
||||
return &mockGeminiSessionCache{sessions: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
|
||||
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
|
||||
value := FormatGeminiSessionValue(uuid, accountID)
|
||||
m.sessions[key] = value
|
||||
}
|
||||
|
||||
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
prefixes := GenerateDigestChainPrefixes(digestChain)
|
||||
for _, p := range prefixes {
|
||||
key := BuildGeminiSessionKey(groupID, prefixHash, p)
|
||||
if val, ok := m.sessions[key]; ok {
|
||||
return ParseGeminiSessionValue(val)
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
|
||||
func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
sessionUUID := "session-uuid-12345"
|
||||
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 1 chain: %s", chain1)
|
||||
|
||||
// 第一轮:没有找到会话,创建新会话
|
||||
_, _, found := cache.Find(groupID, prefixHash, chain1)
|
||||
_, _, _, found := store.Find(groupID, prefixHash, chain1)
|
||||
if found {
|
||||
t.Error("Round 1: should not find existing session")
|
||||
}
|
||||
|
||||
// 保存第一轮会话
|
||||
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
|
||||
// 保存第一轮会话(首轮无旧 chain)
|
||||
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
|
||||
|
||||
// 模拟第二轮对话(用户继续对话)
|
||||
req2 := &antigravity.GeminiRequest{
|
||||
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 2 chain: %s", chain2)
|
||||
|
||||
// 第二轮:应该能找到会话(通过前缀匹配)
|
||||
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
|
||||
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
|
||||
if !found {
|
||||
t.Error("Round 2: should find session via prefix matching")
|
||||
}
|
||||
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
|
||||
}
|
||||
|
||||
// 保存第二轮会话
|
||||
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
|
||||
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
|
||||
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
|
||||
|
||||
// 模拟第三轮对话
|
||||
req3 := &antigravity.GeminiRequest{
|
||||
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
t.Logf("Round 3 chain: %s", chain3)
|
||||
|
||||
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
|
||||
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
|
||||
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
|
||||
if !found {
|
||||
t.Error("Round 3: should find session via prefix matching")
|
||||
}
|
||||
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
|
||||
if foundAccID != accountID {
|
||||
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
|
||||
}
|
||||
|
||||
t.Log("✓ Continuous conversation session matching works correctly!")
|
||||
}
|
||||
|
||||
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
|
||||
func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
|
||||
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chain1 := BuildGeminiDigestChain(req1)
|
||||
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
|
||||
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
|
||||
|
||||
// 第二个完全不同的会话
|
||||
req2 := &antigravity.GeminiRequest{
|
||||
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
|
||||
chain2 := BuildGeminiDigestChain(req2)
|
||||
|
||||
// 不同会话不应该匹配
|
||||
_, _, found := cache.Find(groupID, prefixHash, chain2)
|
||||
_, _, _, found := store.Find(groupID, prefixHash, chain2)
|
||||
if found {
|
||||
t.Error("Different conversations should not match")
|
||||
}
|
||||
|
||||
t.Log("✓ Different conversations are correctly isolated!")
|
||||
}
|
||||
|
||||
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
|
||||
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
|
||||
cache := newMockGeminiSessionCache()
|
||||
store := NewDigestSessionStore()
|
||||
groupID := int64(1)
|
||||
prefixHash := "test_prefix_hash"
|
||||
|
||||
// 创建一个三轮对话
|
||||
req := &antigravity.GeminiRequest{
|
||||
SystemInstruction: &antigravity.GeminiContent{
|
||||
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
|
||||
},
|
||||
Contents: []antigravity.GeminiContent{
|
||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
|
||||
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
|
||||
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
|
||||
},
|
||||
}
|
||||
fullChain := BuildGeminiDigestChain(req)
|
||||
prefixes := GenerateDigestChainPrefixes(fullChain)
|
||||
|
||||
t.Logf("Full chain: %s", fullChain)
|
||||
t.Logf("Prefixes (longest first): %v", prefixes)
|
||||
|
||||
// 验证前缀生成顺序(从长到短)
|
||||
if len(prefixes) != 4 {
|
||||
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
|
||||
}
|
||||
|
||||
// 保存不同轮次的会话到不同账号
|
||||
// 第一轮(最短前缀)-> 账号 1
|
||||
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
|
||||
// 第二轮 -> 账号 2
|
||||
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
|
||||
// 第三轮(最长前缀,完整链)-> 账号 3
|
||||
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
|
||||
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
|
||||
|
||||
// 查找应该返回最长匹配(账号 3)
|
||||
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
|
||||
// 查找更长的链,应该返回最长匹配(账号 3)
|
||||
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
|
||||
if !found {
|
||||
t.Error("Should find session")
|
||||
}
|
||||
if accID != 3 {
|
||||
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
|
||||
}
|
||||
|
||||
t.Log("✓ Longest prefix matching works correctly!")
|
||||
}
|
||||
|
||||
// 确保 context 包被使用(避免未使用的导入警告)
|
||||
var _ = context.Background
|
||||
|
||||
Reference in New Issue
Block a user