From 35598d5648aaca514ebb8741bf37bd5a849dffab Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 9 Feb 2026 06:47:22 +0800 Subject: [PATCH] fix: parse Gemini native request format in ParseGatewayRequest for correct session hash generation ParseGatewayRequest only parsed Anthropic format (system/messages), ignoring Gemini native format (systemInstruction/contents). This caused GenerateSessionHash to produce identical hashes for all Gemini sessions. Add protocol parameter to ParseGatewayRequest to branch between Anthropic and Gemini parsing. Update GenerateSessionHash message traversal to extract text from both formats. --- backend/internal/handler/gateway_handler.go | 5 +- .../internal/handler/gemini_v1beta_handler.go | 3 +- backend/internal/service/gateway_request.go | 39 +- .../internal/service/gateway_request_test.go | 110 +++++- backend/internal/service/gateway_service.go | 17 +- .../service/gateway_service_benchmark_test.go | 2 +- .../service/generate_session_hash_test.go | 370 ++++++++++++++++++ backend/internal/service/ops_retry.go | 5 +- 8 files changed, 524 insertions(+), 27 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 91348608..4daa874a 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -939,7 +940,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index cd7c2d3f..8f0a7bb3 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" @@ -232,7 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { sessionHash := extractGeminiCLISessionHash(c, body) if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) - parsedReq, _ := service.ParseGatewayRequest(body) + parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini) if parsedReq != nil { parsedReq.SessionContext = &service.SessionContext{ ClientIP: ip.GetClientIP(c), diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 519207c9..c039f030 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -6,6 +6,7 @@ import ( "fmt" "math" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) @@ -43,9 +44,10 @@ type ParsedRequest struct { SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) } -// ParseGatewayRequest 解析网关请求体并返回结构化结果 -// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal -func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { +// ParseGatewayRequest 解析网关请求体并返回结构化结果。 +// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), +// 不同协议使用不同的 system/messages 字段名。 +func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return nil, err @@ -74,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.MetadataUserID = userID } } - // system 字段只要存在就视为显式提供(即使为 null), - // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { - parsed.HasSystem = true - parsed.System = system - } - if messages, ok := req["messages"].([]any); ok { - parsed.Messages = messages + + switch protocol { + case domain.PlatformGemini: + // Gemini 原生格式: systemInstruction.parts / contents + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + // Anthropic / OpenAI 格式: system / messages + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } } // thinking: {type: "enabled"} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 4e390b0a..c7519d6e 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -4,12 +4,13 @@ import ( "encoding/json" "testing" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/stretchr/testify/require" ) func TestParseGatewayRequest(t *testing.T) { body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.True(t, parsed.Stream) @@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) { func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.True(t, parsed.ThinkingEnabled) @@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 1, parsed.MaxTokens) } func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 require.True(t, parsed.HasSystem) @@ -53,16 +54,111 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) { body := []byte(`{"model":123}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { body := []byte(`{"stream":"true"}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } +// ============ Gemini 原生格式解析测试 ============ + +func TestParseGatewayRequest_GeminiContents(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there"}]}, + {"role": "user", "parts": [{"text": "How are you?"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Len(t, parsed.Messages, 3, "should parse contents as Messages") + require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") + require.Nil(t, parsed.System, "no systemInstruction means nil System") +} + +func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { + body := []byte(`{ + "systemInstruction": { + "parts": [{"text": "You are a helpful assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") + parts, ok := parsed.System.([]any) + require.True(t, ok) + require.Len(t, parts, 1) + partMap, ok := parts[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "You are a helpful assistant.", partMap["text"]) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { + body := []byte(`{ + "model": "gemini-2.5-pro", + "contents": [{"role": "user", "parts": [{"text": "test"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Equal(t, "gemini-2.5-pro", parsed.Model) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { + // Gemini 格式下 system/messages 字段应被忽略 + body := []byte(`{ + "system": "should be ignored", + "messages": [{"role": "user", "content": "ignored"}], + "contents": [{"role": "user", "parts": [{"text": "real content"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") + require.Nil(t, parsed.System, "no systemInstruction = nil System") + require.Len(t, parsed.Messages, 1, "should use contents, not messages") +} + +func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { + body := []byte(`{"contents": []}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Empty(t, parsed.Messages) +} + +func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { + body := []byte(`{"model": "gemini-2.5-flash"}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Nil(t, parsed.Messages) + require.Equal(t, "gemini-2.5-flash", parsed.Model) +} + +func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { + // Anthropic 格式下 contents/systemInstruction 字段应被忽略 + body := []byte(`{ + "system": "real system", + "messages": [{"role": "user", "content": "real content"}], + "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], + "systemInstruction": {"parts": [{"text": "ignored"}]} + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) + require.NoError(t, err) + require.True(t, parsed.HasSystem) + require.Equal(t, "real system", parsed.System) + require.Len(t, parsed.Messages, 1) + msg := parsed.Messages[0].(map[string]any) + require.Equal(t, "real content", msg["content"]) +} + func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 82fb0e04..6572f25d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -510,9 +510,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } for _, msg := range parsed.Messages { if m, ok := msg.(map[string]any); ok { - msgText := s.extractTextFromContent(m["content"]) - if msgText != "" { - _, _ = combined.WriteString(msgText) + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = combined.WriteString(text) + } + } + } } } } diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go index f15a85d6..c9c4d3dd 100644 --- a/backend/internal/service/gateway_service_benchmark_test.go +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") if err != nil { b.Fatalf("解析请求失败: %v", err) } diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index f315164f..8aa358a5 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -841,3 +841,373 @@ func TestGenerateSessionHash_LongConversation(t *testing.T) { h2 := svc.GenerateSessionHash(parsed2) require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") } + +// ============ Gemini 原生格式 session hash 测试 ============ + +func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) { + svc := &GatewayService{} + + // Gemini 格式: contents[].parts[].text + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello from Gemini"}, + }, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.2.3.4", + UserAgent: "gemini-cli", + APIKeyID: 1, + }, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash") +} + +func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Goodbye"}, + }, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes") +} + +func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "Hi there!"}, + }, + }, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same Gemini contents should produce identical hash") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + round1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + map[string]any{ + "role": "model", + "parts": []any{map[string]any{"text": "Hi!"}}, + }, + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "How are you?"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round") +} + +func TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 核心场景:两个不同用户发送相同 Gemini 格式消息应得到不同 hash + user1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "gemini-cli", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "gemini-cli", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes") +} + +func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // systemInstruction 经 ParseGatewayRequest 解析后存入 parsed.System + withSys := &ParsedRequest{ + System: []any{ + map[string]any{"text": "You are a coding assistant."}, + }, + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + withoutSys := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(withSys) + h2 := svc.GenerateSessionHash(withoutSys) + require.NotEqual(t, h1, h2, "systemInstruction should affect the hash") +} + +func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 多 parts 的消息 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "Part 2"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "multi-part Gemini message should produce a hash") + + // 不同内容的多 parts + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "CHANGED"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing a part should change the hash") +} + +func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 含非 text 类型 parts(如 inline_data),应被跳过但不报错 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Describe this image"}, + map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42} + + // 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。 + // 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。 + // 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。 + round1Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]} + ] + }`) + round2Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]} + ] + }`) + round3Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]}, + {"role": "model", "parts": [{"text": "func hello() error { return nil }"}]}, + {"role": "user", "parts": [{"text": "Now add tests"}]} + ] + }`) + + hashes := make([]string, 3) + for i, body := range [][]byte{round1Body, round2Body, round3Body} { + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = ctx + hashes[i] = svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1) + } + + // 每轮 hash 都不同——这是预期行为 + require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)") + require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)") + require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ") + + // 同一轮重试应产生相同 hash + parsed1Again, err := ParseGatewayRequest(round2Body, "gemini") + require.NoError(t, err) + parsed1Again.SessionContext = ctx + h2Again := svc.GenerateSessionHash(parsed1Again) + require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash") +} + +func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) { + svc := &GatewayService{} + + // 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程 + body := []byte(`{ + "model": "gemini-2.5-pro", + "systemInstruction": { + "parts": [{"text": "You are a coding assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "Here is a function..."}]}, + {"role": "user", "parts": [{"text": "Now add error handling"}]} + ] + }`) + + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash") + + // 同一请求再次解析应产生相同 hash + parsed2, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed2.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h, h2, "same request should produce same hash") + + // 不同用户发送相同请求应产生不同 hash + parsed3, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed3.SessionContext = &SessionContext{ + ClientIP: "10.0.0.2", + UserAgent: "gemini-cli/1.0", + APIKeyID: 99, + } + + h3 := svc.GenerateSessionHash(parsed3) + require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash") +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fbc800f2..23a524ad 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" @@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) { switch reqType { case opsRetryTypeMessages: - parsed, parseErr := ParseGatewayRequest(body) + parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr) } @@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if s.gatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"} } - parsedReq, parseErr := ParseGatewayRequest(body) + parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"} }