From 5c76b9e45a381cfc5ed723808f226a59ff844b24 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 9 Feb 2026 06:46:32 +0800 Subject: [PATCH] fix: prevent sessionHash collision for different users with same messages Mix SessionContext (ClientIP, UserAgent, APIKeyID) into GenerateSessionHash 3rd-level fallback to differentiate requests from different users sending identical content. Also switch hashContent from SHA256-truncated to XXHash64 for better performance, and optimize Trie Lua script to match from longest prefix first. --- backend/internal/handler/gateway_handler.go | 10 + .../internal/handler/gemini_v1beta_handler.go | 7 + backend/internal/repository/gateway_cache.go | 31 +- backend/internal/service/gateway_request.go | 28 +- backend/internal/service/gateway_service.go | 18 +- .../service/generate_session_hash_test.go | 843 ++++++++++++++++++ 6 files changed, 913 insertions(+), 24 deletions(-) create mode 100644 backend/internal/service/generate_session_hash_test.go diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 255d3fab..91348608 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -203,6 +203,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 计算粘性会话hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 @@ -962,6 +967,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 036b5108..cd7c2d3f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -233,6 +233,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) parsedReq, _ := service.ParseGatewayRequest(body) + if parsedReq != nil { + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + } sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) } sessionKey := sessionHash diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 46ae0c16..b9cc521e 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -19,25 +19,34 @@ const ( // ARGV[2] = TTL seconds (用于刷新) // 返回: 最长匹配的 value (uuid:accountID) 或 nil // 查找成功时自动刷新 TTL,防止活跃会话意外过期 + // 从最长前缀(完整 chain)开始逐步缩短,第一次命中即返回 geminiTrieFindScript = ` local chain = ARGV[1] local ttl = tonumber(ARGV[2]) -local lastMatch = nil -local path = "" -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part - local val = redis.call('HGET', KEYS[1], path) +-- 先尝试完整 chain(最常见场景:同一对话的下一轮请求) +local val = redis.call('HGET', KEYS[1], chain) +if val and val ~= "" then + redis.call('EXPIRE', KEYS[1], ttl) + return val +end + +-- 从最长前缀开始逐步缩短(去掉最后一个 "-xxx" 段) +local path = chain +while true do + local i = string.find(path, "-[^-]*$") + if not i or i <= 1 then + break + end + path = string.sub(path, 1, i - 1) + val = redis.call('HGET', KEYS[1], path) if val and val ~= "" then - lastMatch = val + redis.call('EXPIRE', KEYS[1], ttl) + return val end end -if lastMatch then - redis.call('EXPIRE', KEYS[1], ttl) -end - -return lastMatch +return nil ` // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 0ecd18aa..519207c9 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -9,6 +9,15 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) +// SessionContext 粘性会话上下文,用于区分不同来源的请求。 +// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入, +// 避免不同用户发送相同消息产生相同 hash 导致账号集中。 +type SessionContext struct { + ClientIP string + UserAgent string + APIKeyID int64 +} + // ParsedRequest 保存网关请求的预解析结果 // // 性能优化说明: @@ -22,15 +31,16 @@ import ( // 2. 将解析结果 ParsedRequest 传递给 Service 层 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 type ParsedRequest struct { - Body []byte // 原始请求体(保留用于转发) - Model string // 请求的模型名称 - Stream bool // 是否为流式请求 - MetadataUserID string // metadata.user_id(用于会话亲和) - System any // system 字段内容 - Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) - ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) - MaxTokens int // max_tokens 值(用于探测请求拦截) + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) + ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + MaxTokens int // max_tokens 值(用于探测请求拦截) + SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) } // ParseGatewayRequest 解析网关请求体并返回结构化结果 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 480f5b67..82fb0e04 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,6 +16,7 @@ import ( "os" "regexp" "sort" + "strconv" "strings" "sync/atomic" "time" @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -490,8 +491,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return s.hashContent(cacheableContent) } - // 3. 最后 fallback: 使用 system + 所有消息的完整摘要串 + // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 var combined strings.Builder + // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash + if parsed.SessionContext != nil { + _, _ = combined.WriteString(parsed.SessionContext.ClientIP) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) + _, _ = combined.WriteString("|") + } if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { @@ -649,8 +659,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { } func (s *GatewayService) hashContent(content string) string { - hash := sha256.Sum256([]byte(content)) - return hex.EncodeToString(hash[:16]) // 32字符 + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) } // replaceModelInBody 替换请求体中的model字段 diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go new file mode 100644 index 00000000..f315164f --- /dev/null +++ b/backend/internal/service/generate_session_hash_test.go @@ -0,0 +1,843 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ 基础优先级测试 ============ + +func TestGenerateSessionHash_NilParsedRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(nil)) +} + +func TestGenerateSessionHash_EmptyRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(&ParsedRequest{})) +} + +func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, "metadata session_id should have highest priority") +} + +// ============ System + Messages 基础测试 ============ + +func TestGenerateSessionHash_SystemPlusMessages(t *testing.T) { + svc := &GatewayService{} + + withSystem := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + withoutSystem := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withSystem) + h2 := svc.GenerateSessionHash(withoutSystem) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "system prompt should be part of digest, producing different hash") +} + +func TestGenerateSessionHash_SystemOnlyProducesHash(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + } + hash := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hash, "system prompt alone should produce a hash as part of full digest") +} + +func TestGenerateSessionHash_DifferentSystemsSameMessages(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are assistant A.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are assistant B.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different system prompts with same messages should produce different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessages(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same system + same messages should produce identical hash") +} + +func TestGenerateSessionHash_DifferentMessagesProduceDifferentHash(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Go"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Python"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "same system but different messages should produce different hashes") +} + +// ============ SessionContext 核心测试 ============ + +func TestGenerateSessionHash_DifferentSessionContextProducesDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 相同消息 + 不同 SessionContext → 不同 hash(解决碰撞问题的核心场景) + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "curl/7.0", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "same messages but different SessionContext should produce different hashes") +} + +func TestGenerateSessionHash_SameSessionContextProducesSameHash(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same messages + same SessionContext should produce identical hash") +} + +func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, + "metadata session_id should take priority over SessionContext") +} + +func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { + svc := &GatewayService{} + + withCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: nil, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + require.Equal(t, h1, h2, "nil SessionContext should produce same hash as no SessionContext") +} + +// ============ 多轮连续会话测试 ============ + +func TestGenerateSessionHash_ContinuousConversation_HashChangesWithMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟连续会话:每增加一轮对话,hash 应该不同(内容累积变化) + round1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + + round3 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well!"}, + map[string]any{"role": "user", "content": "Tell me a joke"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + h3 := svc.GenerateSessionHash(round3) + + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEmpty(t, h3) + require.NotEqual(t, h1, h2, "different conversation rounds should produce different hashes") + require.NotEqual(t, h2, h3, "each new round should produce a different hash") + require.NotEqual(t, h1, h3, "round 1 and round 3 should differ") +} + +func TestGenerateSessionHash_ContinuousConversation_SameRoundSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 同一轮对话重复请求(如重试)应产生相同 hash + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same conversation state should produce identical hash on retry") +} + +// ============ 消息回退测试 ============ + +func TestGenerateSessionHash_MessageRollback(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟消息回退:用户删掉最后一轮再重发 + original := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "msg3"}, + }, + SessionContext: ctx, + } + + // 回退到 msg2 后,用新的 msg3 替代 + rollback := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "different msg3"}, + }, + SessionContext: ctx, + } + + hOrig := svc.GenerateSessionHash(original) + hRollback := svc.GenerateSessionHash(rollback) + require.NotEqual(t, hOrig, hRollback, "rollback with different last message should produce different hash") +} + +func TestGenerateSessionHash_MessageRollbackSameContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 回退后重新发送相同内容 → 相同 hash(合理的粘性恢复) + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "rollback and resend same content should produce same hash") +} + +// ============ 相同 System、不同用户消息 ============ + +func TestGenerateSessionHash_SameSystemDifferentUsers(t *testing.T) { + svc := &GatewayService{} + + // 两个不同用户使用相同 system prompt 但发送不同消息 + user1 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Go code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "vscode", + APIKeyID: 1, + }, + } + user2 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Python code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "vscode", + APIKeyID: 2, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "different users with different messages should get different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessageDifferentContext(t *testing.T) { + svc := &GatewayService{} + + // 这是修复的核心场景:两个不同用户发送完全相同的 system + messages(如 "hello") + // 有了 SessionContext 后应该产生不同 hash + user1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "Mozilla/5.0", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: same system+messages but different users should get different hashes") +} + +// ============ SessionContext 各字段独立影响测试 ============ + +func TestGenerateSessionHash_SessionContext_IPDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ip string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: ip, + UserAgent: "same-ua", + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("1.1.1.1")) + h2 := svc.GenerateSessionHash(base("2.2.2.2")) + require.NotEqual(t, h1, h2, "different IP should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0")) + h2 := svc.GenerateSessionHash(base("curl/7.0")) + require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(keyID int64) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "same-ua", + APIKeyID: keyID, + }, + } + } + + h1 := svc.GenerateSessionHash(base(1)) + h2 := svc.GenerateSessionHash(base(2)) + require.NotEqual(t, h1, h2, "different APIKeyID should produce different hash") +} + +// ============ 多用户并发相同消息场景 ============ + +func TestGenerateSessionHash_MultipleUsersSameFirstMessage(t *testing.T) { + svc := &GatewayService{} + + // 模拟 5 个不同用户同时发送 "hello" → 应该产生 5 个不同的 hash + hashes := make(map[string]bool) + for i := 0; i < 5; i++ { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1." + string(rune('1'+i)), + UserAgent: "client-" + string(rune('A'+i)), + APIKeyID: int64(i + 1), + }, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + require.False(t, hashes[h], "hash collision detected for user %d", i) + hashes[h] = true + } + require.Len(t, hashes, 5, "5 different users should produce 5 unique hashes") +} + +// ============ 连续会话粘性:多轮对话同一用户 ============ + +func TestGenerateSessionHash_SameUserGrowingConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "browser", APIKeyID: 42} + + // 模拟同一用户的连续会话,每轮 hash 不同但同用户重试保持一致 + messages := []map[string]any{ + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "reply1"}, + {"role": "user", "content": "msg2"}, + {"role": "assistant", "content": "reply2"}, + {"role": "user", "content": "msg3"}, + {"role": "assistant", "content": "reply3"}, + {"role": "user", "content": "msg4"}, + } + + prevHash := "" + for round := 1; round <= len(messages); round += 2 { + // 构建前 round 条消息 + msgs := make([]any, round) + for j := 0; j < round; j++ { + msgs[j] = messages[j] + } + parsed := &ParsedRequest{ + System: "System", + HasSystem: true, + Messages: msgs, + SessionContext: ctx, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "round %d hash should not be empty", round) + + if prevHash != "" { + require.NotEqual(t, prevHash, h, "round %d hash should differ from previous round", round) + } + prevHash = h + + // 同一轮重试应该相同 + h2 := svc.GenerateSessionHash(parsed) + require.Equal(t, h, h2, "retry of round %d should produce same hash", round) + } +} + +// ============ 多轮消息内容结构化测试 ============ + +func TestGenerateSessionHash_MultipleUserMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 5 条用户消息(无 assistant 回复) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "second"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 修改中间一条消息应该改变 hash + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "CHANGED"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing any message should change the hash") +} + +func TestGenerateSessionHash_MessageOrderMatters(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "alpha"}, + map[string]any{"role": "user", "content": "beta"}, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "beta"}, + map[string]any{"role": "user", "content": "alpha"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "message order should affect the hash") +} + +// ============ 复杂内容格式测试 ============ + +func TestGenerateSessionHash_StructuredContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 结构化 content(数组形式) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Look at this"}, + map[string]any{"type": "text", "text": "And this too"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "structured content should produce a hash") +} + +func TestGenerateSessionHash_ArraySystemPrompt(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 数组格式的 system prompt + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant."}, + map[string]any{"type": "text", "text": "Be concise."}, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "array system prompt should produce a hash") +} + +// ============ SessionContext 与 cache_control 优先级 ============ + +func TestGenerateSessionHash_CacheControlOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + // 当有 cache_control: ephemeral 时,使用第 2 级优先级 + // SessionContext 不应影响结果 + parsed1 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "ua1", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "ua2", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h1, h2, "cache_control ephemeral has higher priority, SessionContext should not affect result") +} + +// ============ 边界情况 ============ + +func TestGenerateSessionHash_EmptyMessages(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "test", + APIKeyID: 1, + }, + } + + // 空 messages + 只有 SessionContext 时,combined.Len() > 0 因为有 context 写入 + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "empty messages with SessionContext should still produce a hash from context") +} + +func TestGenerateSessionHash_EmptyMessagesNoContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + } + + h := svc.GenerateSessionHash(parsed) + require.Empty(t, h, "empty messages without SessionContext should produce empty hash") +} + +func TestGenerateSessionHash_SessionContextWithEmptyFields(t *testing.T) { + svc := &GatewayService{} + + // SessionContext 字段为空字符串和零值时仍应影响 hash + withEmptyCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "", + UserAgent: "", + APIKeyID: 0, + }, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + } + + h1 := svc.GenerateSessionHash(withEmptyCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + // 有 SessionContext(即使字段为空)仍然会写入分隔符 "::" 等 + require.NotEqual(t, h1, h2, "empty-field SessionContext should still differ from nil SessionContext") +} + +// ============ 长对话历史测试 ============ + +func TestGenerateSessionHash_LongConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 构建 20 轮对话 + messages := make([]any, 0, 40) + for i := 0; i < 20; i++ { + messages = append(messages, map[string]any{ + "role": "user", + "content": "user message " + string(rune('A'+i)), + }) + messages = append(messages, map[string]any{ + "role": "assistant", + "content": "assistant reply " + string(rune('A'+i)), + }) + } + + parsed := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: messages, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 再加一轮应该不同 + moreMessages := make([]any, len(messages)+2) + copy(moreMessages, messages) + moreMessages[len(messages)] = map[string]any{"role": "user", "content": "one more"} + moreMessages[len(messages)+1] = map[string]any{"role": "assistant", "content": "ok"} + + parsed2 := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: moreMessages, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") +}