Merge pull request #393 from IanShaw027/fix/gemini-thought-signature-preserve

fix(gemini): 修复 thoughtSignature 跨账号验证错误
This commit is contained in:
Wesley Liddick
2026-01-26 09:23:46 +08:00
committed by GitHub
6 changed files with 318 additions and 10 deletions

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -305,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
if sessionHash == "" || s.cache == nil {
return 0, nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err != nil {
return 0, err
}
return accountID, nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 {
return body
}
// 解析 JSON
var data any
if err := json.Unmarshal(body, &data); err != nil {
// 如果解析失败,返回原始 body可能不是 JSON 或格式不正确)
return body
}
// 递归清理 thoughtSignature
cleaned := cleanThoughtSignaturesRecursive(data)
// 重新序列化
result, err := json.Marshal(cleaned)
if err != nil {
// 如果序列化失败,返回原始 body
return body
}
return result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func cleanThoughtSignaturesRecursive(data any) any {
switch v := data.(type) {
case map[string]any:
// 创建新的 map移除 thoughtSignature
result := make(map[string]any, len(v))
for key, value := range v {
// 跳过 thoughtSignature 字段
if key == "thoughtSignature" {
continue
}
// 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value)
}
return result
case []any:
// 递归处理数组中的每个元素
result := make([]any, len(v))
for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item)
}
return result
default:
// 基本类型string, number, bool, null直接返回
return v
}
}