diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 91348608..4daa874a 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -939,7 +940,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index cd7c2d3f..8f0a7bb3 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" @@ -232,7 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { sessionHash := extractGeminiCLISessionHash(c, body) if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) - parsedReq, _ := service.ParseGatewayRequest(body) + parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini) if parsedReq != nil { parsedReq.SessionContext = &service.SessionContext{ ClientIP: ip.GetClientIP(c), diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 519207c9..c039f030 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -6,6 +6,7 @@ import ( "fmt" "math" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) @@ -43,9 +44,10 @@ type ParsedRequest struct { SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) } -// ParseGatewayRequest 解析网关请求体并返回结构化结果 -// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal -func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { +// ParseGatewayRequest 解析网关请求体并返回结构化结果。 +// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), +// 不同协议使用不同的 system/messages 字段名。 +func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return nil, err @@ -74,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.MetadataUserID = userID } } - // system 字段只要存在就视为显式提供(即使为 null), - // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { - parsed.HasSystem = true - parsed.System = system - } - if messages, ok := req["messages"].([]any); ok { - parsed.Messages = messages + + switch protocol { + case domain.PlatformGemini: + // Gemini 原生格式: systemInstruction.parts / contents + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + // Anthropic / OpenAI 格式: system / messages + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } } // thinking: {type: "enabled"} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 4e390b0a..c7519d6e 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -4,12 +4,13 @@ import ( "encoding/json" "testing" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/stretchr/testify/require" ) func TestParseGatewayRequest(t *testing.T) { body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.True(t, parsed.Stream) @@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) { func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.True(t, parsed.ThinkingEnabled) @@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 1, parsed.MaxTokens) } func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 require.True(t, parsed.HasSystem) @@ -53,16 +54,111 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) { body := []byte(`{"model":123}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { body := []byte(`{"stream":"true"}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } +// ============ Gemini 原生格式解析测试 ============ + +func TestParseGatewayRequest_GeminiContents(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there"}]}, + {"role": "user", "parts": [{"text": "How are you?"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Len(t, parsed.Messages, 3, "should parse contents as Messages") + require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") + require.Nil(t, parsed.System, "no systemInstruction means nil System") +} + +func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { + body := []byte(`{ + "systemInstruction": { + "parts": [{"text": "You are a helpful assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") + parts, ok := parsed.System.([]any) + require.True(t, ok) + require.Len(t, parts, 1) + partMap, ok := parts[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "You are a helpful assistant.", partMap["text"]) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { + body := []byte(`{ + "model": "gemini-2.5-pro", + "contents": [{"role": "user", "parts": [{"text": "test"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Equal(t, "gemini-2.5-pro", parsed.Model) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { + // Gemini 格式下 system/messages 字段应被忽略 + body := []byte(`{ + "system": "should be ignored", + "messages": [{"role": "user", "content": "ignored"}], + "contents": [{"role": "user", "parts": [{"text": "real content"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") + require.Nil(t, parsed.System, "no systemInstruction = nil System") + require.Len(t, parsed.Messages, 1, "should use contents, not messages") +} + +func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { + body := []byte(`{"contents": []}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Empty(t, parsed.Messages) +} + +func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { + body := []byte(`{"model": "gemini-2.5-flash"}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Nil(t, parsed.Messages) + require.Equal(t, "gemini-2.5-flash", parsed.Model) +} + +func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { + // Anthropic 格式下 contents/systemInstruction 字段应被忽略 + body := []byte(`{ + "system": "real system", + "messages": [{"role": "user", "content": "real content"}], + "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], + "systemInstruction": {"parts": [{"text": "ignored"}]} + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) + require.NoError(t, err) + require.True(t, parsed.HasSystem) + require.Equal(t, "real system", parsed.System) + require.Len(t, parsed.Messages, 1) + msg := parsed.Messages[0].(map[string]any) + require.Equal(t, "real content", msg["content"]) +} + func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 82fb0e04..6572f25d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -510,9 +510,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } for _, msg := range parsed.Messages { if m, ok := msg.(map[string]any); ok { - msgText := s.extractTextFromContent(m["content"]) - if msgText != "" { - _, _ = combined.WriteString(msgText) + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = combined.WriteString(text) + } + } + } } } } diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go index f15a85d6..c9c4d3dd 100644 --- a/backend/internal/service/gateway_service_benchmark_test.go +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") if err != nil { b.Fatalf("解析请求失败: %v", err) } diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index f315164f..8aa358a5 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -841,3 +841,373 @@ func TestGenerateSessionHash_LongConversation(t *testing.T) { h2 := svc.GenerateSessionHash(parsed2) require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") } + +// ============ Gemini 原生格式 session hash 测试 ============ + +func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) { + svc := &GatewayService{} + + // Gemini 格式: contents[].parts[].text + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello from Gemini"}, + }, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.2.3.4", + UserAgent: "gemini-cli", + APIKeyID: 1, + }, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash") +} + +func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Goodbye"}, + }, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes") +} + +func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "Hi there!"}, + }, + }, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same Gemini contents should produce identical hash") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + round1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + map[string]any{ + "role": "model", + "parts": []any{map[string]any{"text": "Hi!"}}, + }, + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "How are you?"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round") +} + +func TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 核心场景:两个不同用户发送相同 Gemini 格式消息应得到不同 hash + user1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "gemini-cli", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "gemini-cli", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes") +} + +func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // systemInstruction 经 ParseGatewayRequest 解析后存入 parsed.System + withSys := &ParsedRequest{ + System: []any{ + map[string]any{"text": "You are a coding assistant."}, + }, + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + withoutSys := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(withSys) + h2 := svc.GenerateSessionHash(withoutSys) + require.NotEqual(t, h1, h2, "systemInstruction should affect the hash") +} + +func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 多 parts 的消息 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "Part 2"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "multi-part Gemini message should produce a hash") + + // 不同内容的多 parts + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "CHANGED"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing a part should change the hash") +} + +func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 含非 text 类型 parts(如 inline_data),应被跳过但不报错 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Describe this image"}, + map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42} + + // 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。 + // 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。 + // 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。 + round1Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]} + ] + }`) + round2Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]} + ] + }`) + round3Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]}, + {"role": "model", "parts": [{"text": "func hello() error { return nil }"}]}, + {"role": "user", "parts": [{"text": "Now add tests"}]} + ] + }`) + + hashes := make([]string, 3) + for i, body := range [][]byte{round1Body, round2Body, round3Body} { + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = ctx + hashes[i] = svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1) + } + + // 每轮 hash 都不同——这是预期行为 + require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)") + require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)") + require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ") + + // 同一轮重试应产生相同 hash + parsed1Again, err := ParseGatewayRequest(round2Body, "gemini") + require.NoError(t, err) + parsed1Again.SessionContext = ctx + h2Again := svc.GenerateSessionHash(parsed1Again) + require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash") +} + +func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) { + svc := &GatewayService{} + + // 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程 + body := []byte(`{ + "model": "gemini-2.5-pro", + "systemInstruction": { + "parts": [{"text": "You are a coding assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "Here is a function..."}]}, + {"role": "user", "parts": [{"text": "Now add error handling"}]} + ] + }`) + + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash") + + // 同一请求再次解析应产生相同 hash + parsed2, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed2.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h, h2, "same request should produce same hash") + + // 不同用户发送相同请求应产生不同 hash + parsed3, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed3.SessionContext = &SessionContext{ + ClientIP: "10.0.0.2", + UserAgent: "gemini-cli/1.0", + APIKeyID: 99, + } + + h3 := svc.GenerateSessionHash(parsed3) + require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash") +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fbc800f2..23a524ad 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" @@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) { switch reqType { case opsRetryTypeMessages: - parsed, parseErr := ParseGatewayRequest(body) + parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr) } @@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if s.gatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"} } - parsedReq, parseErr := ParseGatewayRequest(body) + parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"} }