refactor: replace Trie-based digest session store with flat cache

This commit is contained in:
erio
2026-02-09 07:02:12 +08:00
parent 72b08f9cc5
commit b889d5017b
18 changed files with 428 additions and 640 deletions

View File

@@ -2,7 +2,6 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -12,9 +11,6 @@ const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
anthropicTrieKeyPrefix = "anthropic:trie:"
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
}
}
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
// 格式: anthropic:trie:{groupID}:{prefixHash}
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {

View File

@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
}
}
func TestBuildAnthropicTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "anthropic:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "anthropic:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "anthropic:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string

View File

@@ -0,0 +1,69 @@
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}

View File

@@ -0,0 +1,312 @@
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key无 oldDigestChain模拟最坏场景全是新会话首轮
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息oldDigestChain == digestChain不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}

View File

@@ -224,22 +224,6 @@ func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, acc
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int

View File

@@ -305,23 +305,6 @@ type GatewayCache interface {
// GetModelLoadBatch 批量获取账号的模型负载信息Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息uuid, accountID
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -416,6 +399,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
@@ -449,6 +433,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -458,6 +443,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
@@ -557,35 +543,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息uuid, accountID
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {

View File

@@ -277,22 +277,6 @@ func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accou
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()

View File

@@ -6,26 +6,11 @@ import (
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash16 字符)
// XXHash64 比 SHA256 快约 10 倍Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话

View File

@@ -1,41 +1,14 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
_, _, _, found := store.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 保存第一轮会话(首轮无旧 chain
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
_, _, _, found := store.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
// 查找应该返回最长匹配(账号 3
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
// 查找更长的链,应该返回最长匹配(账号 3
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background

View File

@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}

View File

@@ -212,22 +212,6 @@ func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []i
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
NewDigestSessionStore,
)