diff --git a/backend/internal/handler/gemini_cli_session_test.go b/backend/internal/handler/gemini_cli_session_test.go new file mode 100644 index 00000000..0b37f5f2 --- /dev/null +++ b/backend/internal/handler/gemini_cli_session_test.go @@ -0,0 +1,122 @@ +//go:build unit + +package handler + +import ( + "crypto/sha256" + "encoding/hex" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractGeminiCLISessionHash(t *testing.T) { + tests := []struct { + name string + body string + privilegedUserID string + wantEmpty bool + wantHash string + }{ + { + name: "with privileged-user-id and tmp dir", + body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`, + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: false, + wantHash: func() string { + combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740" + hash := sha256.Sum256([]byte(combined)) + return hex.EncodeToString(hash[:]) + }(), + }, + { + name: "without privileged-user-id but with tmp dir", + body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`, + privilegedUserID: "", + wantEmpty: false, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "without tmp dir", + body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`, + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: true, + }, + { + name: "empty body", + body: "", + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建测试上下文 + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/test", nil) + if tt.privilegedUserID != "" { + c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID) + } + + // 调用函数 + result := extractGeminiCLISessionHash(c, []byte(tt.body)) + + // 验证结果 + if tt.wantEmpty { + require.Empty(t, result, "expected empty session hash") + } else { + require.NotEmpty(t, result, "expected non-empty session hash") + require.Equal(t, tt.wantHash, result, "session hash mismatch") + } + }) + } +} + +func TestGeminiCLITmpDirRegex(t *testing.T) { + tests := []struct { + name string + input string + wantMatch bool + wantHash string + }{ + { + name: "valid tmp dir path", + input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + wantMatch: true, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "valid tmp dir path in text", + input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text", + wantMatch: true, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "invalid hash length", + input: "/Users/ianshaw/.gemini/tmp/abc123", + wantMatch: false, + }, + { + name: "no tmp dir", + input: "Hello world", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input) + if tt.wantMatch { + require.NotNil(t, match, "expected regex to match") + require.Len(t, match, 2, "expected 2 capture groups") + require.Equal(t, tt.wantHash, match[1], "hash mismatch") + } else { + require.Nil(t, match, "expected regex not to match") + } + }) + } +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index c7646b38..32f83013 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -1,11 +1,15 @@ package handler import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "io" "log" "net/http" + "regexp" "strings" "time" @@ -19,6 +23,17 @@ import ( "github.com/gin-gonic/gin" ) +// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 +// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希] +var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`) + +func isGeminiCLIRequest(c *gin.Context, body []byte) bool { + if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" { + return true + } + return geminiCLITmpDirRegex.Match(body) +} + // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { @@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 3) select account (sticky session based on request body) - parsedReq, _ := service.ParseGatewayRequest(body) - sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希) + sessionHash := extractGeminiCLISessionHash(c, body) + if sessionHash == "" { + // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) + parsedReq, _ := service.ParseGatewayRequest(body) + sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) + } sessionKey := sessionHash if sessionHash != "" { sessionKey = "gemini:" + sessionHash } + + // 查询粘性会话绑定的账号 ID(用于检测账号切换) + var sessionBoundAccountID int64 + if sessionKey != "" { + sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + } + isCLI := isGeminiCLIRequest(c, body) + cleanedForUnknownBinding := false + maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) @@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) + // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature + // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 + if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { + log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) + body = service.CleanGeminiNativeThoughtSignatures(body) + sessionBoundAccountID = account.ID + } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { + // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。 + // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 + log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively") + body = service.CleanGeminiNativeThoughtSignatures(body) + cleanedForUnknownBinding = true + sessionBoundAccountID = account.ID + } else if sessionBoundAccountID == 0 { + // 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。 + sessionBoundAccountID = account.ID + } + // 4) account concurrency slot accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { @@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { } return false } + +// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 +// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 +// +// 会话标识生成策略: +// 1. 从请求体中提取 tmp 目录哈希(64位十六进制) +// 2. 从 header 中提取 privileged-user-id(UUID) +// 3. 组合两者生成 SHA256 哈希作为最终的会话标识 +// +// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。 +// +// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests. +// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body. +func extractGeminiCLISessionHash(c *gin.Context, body []byte) string { + // 1. 从请求体中提取 tmp 目录哈希 + match := geminiCLITmpDirRegex.FindSubmatch(body) + if len(match) < 2 { + return "" // 没有找到 tmp 目录,不使用粘性会话 + } + tmpDirHash := string(match[1]) + + // 2. 提取 privileged-user-id + privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) + + // 3. 组合生成最终的 session hash + if privilegedUserID != "" { + // 组合两个标识符:privileged-user-id + tmp 目录哈希 + combined := privilegedUserID + ":" + tmpDirHash + hash := sha256.Sum256([]byte(combined)) + return hex.EncodeToString(hash[:]) + } + + // 如果没有 privileged-user-id,直接使用 tmp 目录哈希 + return tmpDirHash +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9565da29..a8f5baeb 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -305,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) } +// GetCachedSessionAccountID retrieves the account ID bound to a sticky session. +// Returns 0 if no binding exists or on error. +func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if sessionHash == "" || s.cache == nil { + return 0, nil + } + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err != nil { + return 0, err + } + return accountID, nil +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" diff --git a/backend/internal/service/gemini_native_signature_cleaner.go b/backend/internal/service/gemini_native_signature_cleaner.go new file mode 100644 index 00000000..b3352fb0 --- /dev/null +++ b/backend/internal/service/gemini_native_signature_cleaner.go @@ -0,0 +1,72 @@ +package service + +import ( + "encoding/json" +) + +// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段, +// 以避免跨账号签名验证错误。 +// +// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature +// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。 +// +// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests +// to avoid cross-account signature validation errors. +// +// When sticky session switches accounts (e.g., original account becomes unavailable), +// thoughtSignatures from the old account will cause validation failures on the new account. +// By removing these signatures, we allow the new account to generate valid signatures. +func CleanGeminiNativeThoughtSignatures(body []byte) []byte { + if len(body) == 0 { + return body + } + + // 解析 JSON + var data any + if err := json.Unmarshal(body, &data); err != nil { + // 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确) + return body + } + + // 递归清理 thoughtSignature + cleaned := cleanThoughtSignaturesRecursive(data) + + // 重新序列化 + result, err := json.Marshal(cleaned) + if err != nil { + // 如果序列化失败,返回原始 body + return body + } + + return result +} + +// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段 +func cleanThoughtSignaturesRecursive(data any) any { + switch v := data.(type) { + case map[string]any: + // 创建新的 map,移除 thoughtSignature + result := make(map[string]any, len(v)) + for key, value := range v { + // 跳过 thoughtSignature 字段 + if key == "thoughtSignature" { + continue + } + // 递归处理嵌套结构 + result[key] = cleanThoughtSignaturesRecursive(value) + } + return result + + case []any: + // 递归处理数组中的每个元素 + result := make([]any, len(v)) + for i, item := range v { + result[i] = cleanThoughtSignaturesRecursive(item) + } + return result + + default: + // 基本类型(string, number, bool, null)直接返回 + return v + } +}