Previously, thoughtSignature cleanup only applied to Gemini CLI requests (detected via x-gemini-api-privileged-user-id header or tmp dir pattern). This caused 400 errors for non-CLI clients when session cache expired and they sent stale signatures. Remove the isGeminiCLIRequest guard so all clients benefit from proactive thoughtSignature cleanup on session binding miss.
661 lines
23 KiB
Go
661 lines
23 KiB
Go
package handler
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/google/uuid"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// geminiCLITmpDirRegex extracts the tmp-directory hash from a Gemini CLI request body.
|
||
// It matches paths such as /Users/xxx/.gemini/tmp/<64 hex characters>; capture group 1 is the hash.
|
||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||
|
||
// GeminiV1BetaListModels proxies:
|
||
// GET /v1beta/models
|
||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||
if !ok || apiKey == nil {
|
||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||
return
|
||
}
|
||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||
return
|
||
}
|
||
|
||
// 强制 antigravity 模式:返回 antigravity 支持的模型列表
|
||
if forcePlatform == service.PlatformAntigravity {
|
||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList())
|
||
return
|
||
}
|
||
|
||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||
if err != nil {
|
||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||
if hasAntigravity {
|
||
// antigravity 账户使用静态模型列表
|
||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||
return
|
||
}
|
||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||
return
|
||
}
|
||
|
||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
|
||
if err != nil {
|
||
googleError(c, http.StatusBadGateway, err.Error())
|
||
return
|
||
}
|
||
if shouldFallbackGeminiModels(res) {
|
||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||
return
|
||
}
|
||
writeUpstreamResponse(c, res)
|
||
}
|
||
|
||
// GeminiV1BetaGetModel proxies:
|
||
// GET /v1beta/models/{model}
|
||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||
if !ok || apiKey == nil {
|
||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||
return
|
||
}
|
||
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||
return
|
||
}
|
||
|
||
modelName := strings.TrimSpace(c.Param("model"))
|
||
if modelName == "" {
|
||
googleError(c, http.StatusBadRequest, "Missing model in URL")
|
||
return
|
||
}
|
||
|
||
// 强制 antigravity 模式:返回 antigravity 模型信息
|
||
if forcePlatform == service.PlatformAntigravity {
|
||
c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName))
|
||
return
|
||
}
|
||
|
||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||
if err != nil {
|
||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||
if hasAntigravity {
|
||
// antigravity 账户使用静态模型信息
|
||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||
return
|
||
}
|
||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||
return
|
||
}
|
||
|
||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
|
||
if err != nil {
|
||
googleError(c, http.StatusBadGateway, err.Error())
|
||
return
|
||
}
|
||
if shouldFallbackGeminiModels(res) {
|
||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||
return
|
||
}
|
||
writeUpstreamResponse(c, res)
|
||
}
|
||
|
||
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
|
||
// POST /v1beta/models/{model}:generateContent
|
||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||
if !ok || apiKey == nil {
|
||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||
return
|
||
}
|
||
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
|
||
if !ok {
|
||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||
return
|
||
}
|
||
|
||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||
if !middleware.HasForcePlatform(c) {
|
||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||
return
|
||
}
|
||
}
|
||
|
||
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
||
if err != nil {
|
||
googleError(c, http.StatusNotFound, err.Error())
|
||
return
|
||
}
|
||
|
||
stream := action == "streamGenerateContent"
|
||
|
||
body, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||
return
|
||
}
|
||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||
return
|
||
}
|
||
if len(body) == 0 {
|
||
googleError(c, http.StatusBadRequest, "Request body is empty")
|
||
return
|
||
}
|
||
|
||
setOpsRequestContext(c, modelName, stream, body)
|
||
|
||
// Get subscription (may be nil)
|
||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||
|
||
// For Gemini native API, do not send Claude-style ping frames.
|
||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||
|
||
// 0) wait queue check
|
||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||
waitCounted := false
|
||
if err != nil {
|
||
log.Printf("Increment wait count failed: %v", err)
|
||
} else if !canWait {
|
||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||
return
|
||
}
|
||
if err == nil && canWait {
|
||
waitCounted = true
|
||
}
|
||
defer func() {
|
||
if waitCounted {
|
||
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||
}
|
||
}()
|
||
|
||
// 1) user concurrency slot
|
||
streamStarted := false
|
||
if h.errorPassthroughService != nil {
|
||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||
}
|
||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||
if err != nil {
|
||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||
return
|
||
}
|
||
if waitCounted {
|
||
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||
waitCounted = false
|
||
}
|
||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||
if userReleaseFunc != nil {
|
||
defer userReleaseFunc()
|
||
}
|
||
|
||
// 2) billing eligibility check (after wait)
|
||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||
status, _, message := billingErrorDetails(err)
|
||
googleError(c, status, message)
|
||
return
|
||
}
|
||
|
||
// 3) select account (sticky session based on request body)
|
||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||
if sessionHash == "" {
|
||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||
}
|
||
sessionKey := sessionHash
|
||
if sessionHash != "" {
|
||
sessionKey = "gemini:" + sessionHash
|
||
}
|
||
|
||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||
var sessionBoundAccountID int64
|
||
if sessionKey != "" {
|
||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||
}
|
||
|
||
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
||
// 当原有会话标识无效时(sessionBoundAccountID == 0),尝试基于内容摘要链匹配
|
||
var geminiDigestChain string
|
||
var geminiPrefixHash string
|
||
var geminiSessionUUID string
|
||
useDigestFallback := sessionBoundAccountID == 0
|
||
|
||
if useDigestFallback {
|
||
// 解析 Gemini 请求体
|
||
var geminiReq antigravity.GeminiRequest
|
||
if err := json.Unmarshal(body, &geminiReq); err == nil && len(geminiReq.Contents) > 0 {
|
||
// 生成摘要链
|
||
geminiDigestChain = service.BuildGeminiDigestChain(&geminiReq)
|
||
if geminiDigestChain != "" {
|
||
// 生成前缀 hash
|
||
userAgent := c.GetHeader("User-Agent")
|
||
clientIP := ip.GetClientIP(c)
|
||
platform := ""
|
||
if apiKey.Group != nil {
|
||
platform = apiKey.Group.Platform
|
||
}
|
||
geminiPrefixHash = service.GenerateGeminiPrefixHash(
|
||
authSubject.UserID,
|
||
apiKey.ID,
|
||
clientIP,
|
||
userAgent,
|
||
platform,
|
||
modelName,
|
||
)
|
||
|
||
// 查找会话
|
||
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
|
||
c.Request.Context(),
|
||
derefGroupID(apiKey.GroupID),
|
||
geminiPrefixHash,
|
||
geminiDigestChain,
|
||
)
|
||
if found {
|
||
sessionBoundAccountID = foundAccountID
|
||
geminiSessionUUID = foundUUID
|
||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
|
||
|
||
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
|
||
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
|
||
if sessionKey == "" {
|
||
sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, foundUUID)
|
||
}
|
||
_ = h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, foundAccountID)
|
||
} else {
|
||
// 生成新的会话 UUID
|
||
geminiSessionUUID = uuid.New().String()
|
||
// 为新会话也生成 sessionKey(用于后续请求的粘性会话)
|
||
if sessionKey == "" {
|
||
sessionKey = service.GenerateGeminiDigestSessionKey(geminiPrefixHash, geminiSessionUUID)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||
cleanedForUnknownBinding := false
|
||
|
||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||
switchCount := 0
|
||
failedAccountIDs := make(map[int64]struct{})
|
||
var lastFailoverErr *service.UpstreamFailoverError
|
||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||
|
||
for {
|
||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||
if err != nil {
|
||
if len(failedAccountIDs) == 0 {
|
||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||
return
|
||
}
|
||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||
return
|
||
}
|
||
account := selection.Account
|
||
setOpsSelectedAccount(c, account.ID)
|
||
|
||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||
sessionBoundAccountID = account.ID
|
||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||
cleanedForUnknownBinding = true
|
||
sessionBoundAccountID = account.ID
|
||
} else if sessionBoundAccountID == 0 {
|
||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||
sessionBoundAccountID = account.ID
|
||
}
|
||
|
||
// 4) account concurrency slot
|
||
accountReleaseFunc := selection.ReleaseFunc
|
||
if !selection.Acquired {
|
||
if selection.WaitPlan == nil {
|
||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||
return
|
||
}
|
||
accountWaitCounted := false
|
||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||
if err != nil {
|
||
log.Printf("Increment account wait count failed: %v", err)
|
||
} else if !canWait {
|
||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||
return
|
||
}
|
||
if err == nil && canWait {
|
||
accountWaitCounted = true
|
||
}
|
||
defer func() {
|
||
if accountWaitCounted {
|
||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||
}
|
||
}()
|
||
|
||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||
c,
|
||
account.ID,
|
||
selection.WaitPlan.MaxConcurrency,
|
||
selection.WaitPlan.Timeout,
|
||
stream,
|
||
&streamStarted,
|
||
)
|
||
if err != nil {
|
||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||
return
|
||
}
|
||
if accountWaitCounted {
|
||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||
accountWaitCounted = false
|
||
}
|
||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||
log.Printf("Bind sticky session failed: %v", err)
|
||
}
|
||
}
|
||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||
|
||
// 5) forward (根据平台分流)
|
||
var result *service.ForwardResult
|
||
requestCtx := c.Request.Context()
|
||
if switchCount > 0 {
|
||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||
}
|
||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||
} else {
|
||
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
|
||
}
|
||
if accountReleaseFunc != nil {
|
||
accountReleaseFunc()
|
||
}
|
||
if err != nil {
|
||
var failoverErr *service.UpstreamFailoverError
|
||
if errors.As(err, &failoverErr) {
|
||
failedAccountIDs[account.ID] = struct{}{}
|
||
if failoverErr.ForceCacheBilling {
|
||
forceCacheBilling = true
|
||
}
|
||
if switchCount >= maxAccountSwitches {
|
||
lastFailoverErr = failoverErr
|
||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||
return
|
||
}
|
||
lastFailoverErr = failoverErr
|
||
switchCount++
|
||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||
continue
|
||
}
|
||
// ForwardNative already wrote the response
|
||
log.Printf("Gemini native forward failed: %v", err)
|
||
return
|
||
}
|
||
|
||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||
userAgent := c.GetHeader("User-Agent")
|
||
clientIP := ip.GetClientIP(c)
|
||
|
||
// 保存 Gemini 内容摘要会话(用于 Fallback 匹配)
|
||
if useDigestFallback && geminiDigestChain != "" && geminiPrefixHash != "" {
|
||
if err := h.gatewayService.SaveGeminiSession(
|
||
c.Request.Context(),
|
||
derefGroupID(apiKey.GroupID),
|
||
geminiPrefixHash,
|
||
geminiDigestChain,
|
||
geminiSessionUUID,
|
||
account.ID,
|
||
); err != nil {
|
||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||
}
|
||
}
|
||
|
||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||
Result: result,
|
||
APIKey: apiKey,
|
||
User: apiKey.User,
|
||
Account: usedAccount,
|
||
Subscription: subscription,
|
||
UserAgent: ua,
|
||
IPAddress: ip,
|
||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||
ForceCacheBilling: fcb,
|
||
APIKeyService: h.apiKeyService,
|
||
}); err != nil {
|
||
log.Printf("Record usage failed: %v", err)
|
||
}
|
||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||
return
|
||
}
|
||
}
|
||
|
||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||
rest = strings.TrimSpace(rest)
|
||
if rest == "" {
|
||
return "", "", &pathParseError{"missing path"}
|
||
}
|
||
|
||
// Standard: {model}:{action}
|
||
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
|
||
return rest[:i], rest[i+1:], nil
|
||
}
|
||
|
||
// Fallback: {model}/{action}
|
||
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
|
||
return rest[:i], rest[i+1:], nil
|
||
}
|
||
|
||
return "", "", &pathParseError{"invalid model action path"}
|
||
}
|
||
|
||
func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
|
||
if failoverErr == nil {
|
||
googleError(c, http.StatusBadGateway, "Upstream request failed")
|
||
return
|
||
}
|
||
|
||
statusCode := failoverErr.StatusCode
|
||
responseBody := failoverErr.ResponseBody
|
||
|
||
// 先检查透传规则
|
||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
|
||
// 确定响应状态码
|
||
respCode := statusCode
|
||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||
respCode = *rule.ResponseCode
|
||
}
|
||
|
||
// 确定响应消息
|
||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||
msg = *rule.CustomMessage
|
||
}
|
||
|
||
googleError(c, respCode, msg)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 使用默认的错误映射
|
||
status, message := mapGeminiUpstreamError(statusCode)
|
||
googleError(c, status, message)
|
||
}
|
||
|
||
// mapGeminiUpstreamError translates an upstream HTTP status code into the
// status code and message returned to the client.
//
// 429 stays 429 and 529 becomes 503; every other status, including unknown
// ones, is reported as 502 with a message that depends on the upstream code.
func mapGeminiUpstreamError(statusCode int) (int, string) {
	// 529 has no named constant in net/http.
	const statusOverloaded = 529

	switch statusCode {
	case http.StatusUnauthorized:
		return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
	case statusOverloaded:
		return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
	case http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return http.StatusBadGateway, "Upstream service temporarily unavailable"
	}
	return http.StatusBadGateway, "Upstream request failed"
}
|
||
|
||
// pathParseError is the error type returned when a model/action path cannot
// be parsed; msg is the text reported by Error.
type pathParseError struct {
	msg string
}

// Error implements the error interface and returns the stored message.
func (e *pathParseError) Error() string {
	return e.msg
}
|
||
|
||
func googleError(c *gin.Context, status int, message string) {
|
||
c.JSON(status, gin.H{
|
||
"error": gin.H{
|
||
"code": status,
|
||
"message": message,
|
||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||
},
|
||
})
|
||
}
|
||
|
||
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
|
||
if res == nil {
|
||
googleError(c, http.StatusBadGateway, "Empty upstream response")
|
||
return
|
||
}
|
||
for k, vv := range res.Headers {
|
||
// Avoid overriding content-length and hop-by-hop headers.
|
||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
|
||
continue
|
||
}
|
||
for _, v := range vv {
|
||
c.Writer.Header().Add(k, v)
|
||
}
|
||
}
|
||
contentType := res.Headers.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
c.Data(res.StatusCode, contentType, res.Body)
|
||
}
|
||
|
||
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||
if res == nil {
|
||
return true
|
||
}
|
||
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
|
||
return false
|
||
}
|
||
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
|
||
return true
|
||
}
|
||
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
|
||
return true
|
||
}
|
||
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||
//
|
||
// 会话标识生成策略:
|
||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||
//
|
||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||
//
|
||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||
// 1. 从请求体中提取 tmp 目录哈希
|
||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||
if len(match) < 2 {
|
||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||
}
|
||
tmpDirHash := string(match[1])
|
||
|
||
// 2. 提取 privileged-user-id
|
||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||
|
||
// 3. 组合生成最终的 session hash
|
||
if privilegedUserID != "" {
|
||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||
combined := privilegedUserID + ":" + tmpDirHash
|
||
hash := sha256.Sum256([]byte(combined))
|
||
return hex.EncodeToString(hash[:])
|
||
}
|
||
|
||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||
return tmpDirHash
|
||
}
|
||
|
||
// truncateDigestChain shortens a digest chain for log output: chains longer
// than 50 bytes are cut to their first 50 bytes followed by "...".
func truncateDigestChain(chain string) string {
	const maxLogLen = 50
	if len(chain) > maxLogLen {
		return chain[:maxLogLen] + "..."
	}
	return chain
}
|
||
|
||
// safeShortPrefix returns the first n bytes of value, or value unchanged when
// it is not longer than n or when n is not positive.
// It is used for log output and never slices out of range.
func safeShortPrefix(value string, n int) string {
	if n > 0 && n < len(value) {
		return value[:n]
	}
	return value
}
|
||
|
||
// derefGroupID safely dereferences a *int64 group ID; a nil pointer yields 0.
func derefGroupID(groupID *int64) int64 {
	var id int64
	if groupID != nil {
		id = *groupID
	}
	return id
}
|