Files
sub2api/backend/internal/service/gemini_messages_compat_service.go
song 0059a232a6 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-26 20:51:40 +08:00

3075 lines
89 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
mathrand "math/rand"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
const geminiStickySessionTTL = time.Hour
const (
geminiMaxRetries = 5
geminiRetryBaseDelay = 1 * time.Second
geminiRetryMaxDelay = 16 * time.Second
)
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
groupRepo GroupRepository
cache GatewayCache
schedulerSnapshot *SchedulerSnapshotService
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
}
func NewGeminiMessagesCompatService(
accountRepo AccountRepository,
groupRepo GroupRepository,
cache GatewayCache,
schedulerSnapshot *SchedulerSnapshotService,
tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
groupRepo: groupRepo,
cache: cache,
schedulerSnapshot: schedulerSnapshot,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
}
}
// GetTokenProvider returns the token provider for OAuth accounts
func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
if err != nil {
return nil, err
}
cacheKey := "gemini:" + sessionHash
// 2. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
return account, nil
}
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, false, true, nil
}
if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return "", false, false, fmt.Errorf("get group failed: %w", err)
}
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return group.Platform, group.Platform == PlatformGemini, false, nil
}
// 无分组时只使用原生 gemini 平台
return PlatformGemini, true, false, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func (s *GeminiMessagesCompatService) tryStickySessionHit(
ctx context.Context,
groupID *int64,
sessionHash, cacheKey, requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
return false
}
// 检查模型支持
// Check model support
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
return false
}
// 检查平台匹配
// Check platform matching
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
return false
}
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
return false
}
return true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
if account.Platform == platform {
return true
}
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
return true
}
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
ctx context.Context,
accounts []Account,
requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
continue
}
// 选择最佳账号
if selected == nil {
selected = acc
continue
}
if s.isBetterGeminiAccount(acc, selected) {
selected = acc
}
}
return selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则优先级更高数值更小优先同优先级时未使用过的优先OAuth > 非 OAuth其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
}
// 同优先级,比较最后使用时间
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
}
return account.IsModelSupported(requestedModel)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
}
return s.accountRepo.GetByID(ctx, accountID)
}
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
return accounts, err
}
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
queryPlatforms := []string{platform}
if useMixedScheduling {
queryPlatforms = []string{platform, PlatformAntigravity}
}
if groupID != nil {
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
}
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
if err != nil {
return false, err
}
return len(accounts) > 0, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
// Preference order:
// 1) API key accounts (AI Studio)
// 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
if len(accounts) == 0 {
return nil, errors.New("no available Gemini accounts")
}
rank := func(a *Account) int {
if a == nil {
return 999
}
switch a.Type {
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
return 9
case AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1
}
if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" {
return 2
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
default:
return 10
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if selected == nil {
selected = acc
continue
}
r1, r2 := rank(acc), rank(selected)
if r1 < r2 {
selected = acc
continue
}
if r1 > r2 {
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
return nil, errors.New("no available Gemini accounts")
}
return selected, nil
}
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
if strings.TrimSpace(req.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
originalClaudeBody := body
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
}
switch account.Type {
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
action := "generateContent"
if useUpstreamStream {
action = "streamGenerateContent"
}
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(geminiReq, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
var resp *http.Response
signatureRetryStage := 0
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error())
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if isGeminiSignatureRelatedError(respBody) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
var strippedClaudeBody []byte
stageName := ""
switch signatureRetryStage {
case 0:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
stageName = "thinking-only"
signatureRetryStage = 1
default:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
stageName = "thinking+tools"
signatureRetryStage = 2
}
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
if txErr == nil {
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
geminiReq = retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff(1)
continue
}
}
// Restore body for downstream error handling.
resp = &http.Response{
StatusCode: http.StatusBadRequest,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// Don't treat insufficient-scope as transient.
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode == 429 {
// Mark as rate-limited early so concurrent requests avoid this account.
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
// Final attempt: surface the upstream error body (mapped below) instead of a generic retry error.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
}
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
if req.Stream {
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
}
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
c.JSON(http.StatusOK, claudeResp)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
usage = usageObj
}
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
func isGeminiSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
useUpstreamStream := stream
upstreamAction := action
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
}
forceAIStudio := action == "countTokens"
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(body, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
var resp *http.Response
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error())
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error())
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// Don't treat insufficient-scope as transient.
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode == 429 {
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
if tempMatched {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
respBody = unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
if upstreamMsg == "" {
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("gemini upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream {
streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream")
}
b, _ := json.Marshal(collected)
c.Data(http.StatusOK, "application/json", b)
usage = usageObj
} else {
usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth)
if err != nil {
return nil, err
}
usage = usageResp
}
}
if usage == nil {
usage = &ClaudeUsage{}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
case 403:
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
if account == nil || account.Type != AccountTypeOAuth {
return false
}
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
if oauthType == "" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Legacy/implicit Code Assist OAuth accounts.
oauthType = "code_assist"
}
return oauthType == "code_assist"
default:
return false
}
}
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func sleepGeminiBackoff(attempt int) {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
time.Sleep(sleepFor)
}
var (
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)
func sanitizeUpstreamErrorMessage(msg string) string {
if msg == "" {
return msg
}
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
}
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
var statusCode int
var errType, errMsg string
if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
errType = mapped.Type
if mapped.Message != "" {
errMsg = mapped.Message
}
if mapped.StatusCode > 0 {
statusCode = mapped.StatusCode
}
}
switch upstreamStatus {
case 400:
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
if errType == "" {
errType = "invalid_request_error"
}
if errMsg == "" {
errMsg = "Invalid request"
}
case 401:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "authentication_error"
}
if errMsg == "" {
errMsg = "Upstream authentication failed, please contact administrator"
}
case 403:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "permission_error"
}
if errMsg == "" {
errMsg = "Upstream access forbidden, please contact administrator"
}
case 404:
if statusCode == 0 {
statusCode = http.StatusNotFound
}
if errType == "" {
errType = "not_found_error"
}
if errMsg == "" {
errMsg = "Resource not found"
}
case 429:
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}
if errType == "" {
errType = "rate_limit_error"
}
if errMsg == "" {
errMsg = "Upstream rate limit exceeded, please retry later"
}
case 529:
if statusCode == 0 {
statusCode = http.StatusServiceUnavailable
}
if errType == "" {
errType = "overloaded_error"
}
if errMsg == "" {
errMsg = "Upstream service overloaded, please retry later"
}
case 500, 502, 503, 504:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
switch upstreamStatus {
case 504:
errType = "timeout_error"
case 503:
errType = "overloaded_error"
default:
errType = "api_error"
}
}
if errMsg == "" {
errMsg = "Upstream service temporarily unavailable"
}
default:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "upstream_error"
}
if errMsg == "" {
errMsg = "Upstream request failed"
}
}
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
type claudeErrorMapping struct {
Type string
Message string
StatusCode int
}
func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping {
if len(body) == 0 {
return nil
}
var parsed struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" {
return nil
}
mapped := &claudeErrorMapping{
Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status),
Message: "",
}
if mapped.Type == "" {
mapped.Type = "upstream_error"
}
switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) {
case "INVALID_ARGUMENT":
mapped.StatusCode = http.StatusBadRequest
case "NOT_FOUND":
mapped.StatusCode = http.StatusNotFound
case "RESOURCE_EXHAUSTED":
mapped.StatusCode = http.StatusTooManyRequests
default:
// Keep StatusCode unset and let HTTP status mapping decide.
}
// Keep messages generic by default; upstream error message can be long or include sensitive fragments.
return mapped
}
func mapGeminiStatusToClaudeErrorType(status string) string {
switch strings.ToUpper(strings.TrimSpace(status)) {
case "INVALID_ARGUMENT":
return "invalid_request_error"
case "PERMISSION_DENIED":
return "permission_error"
case "NOT_FOUND":
return "not_found_error"
case "RESOURCE_EXHAUSTED":
return "rate_limit_error"
case "UNAUTHENTICATED":
return "authentication_error"
case "UNAVAILABLE":
return "overloaded_error"
case "INTERNAL":
return "api_error"
case "DEADLINE_EXCEEDED":
return "timeout_error"
default:
return ""
}
}
type geminiStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
geminiResp, err := unwrapGeminiResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel)
c.JSON(http.StatusOK, claudeResp)
return usage, nil
}
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
messageID := "msg_" + randomHex(12)
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": messageID,
"type": "message",
"role": "assistant",
"model": originalModel,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
writeSSE(c.Writer, "message_start", messageStart)
flusher.Flush()
var firstTokenMs *int
var usage ClaudeUsage
finishReason := ""
sawToolUse := false
nextBlockIndex := 0
openBlockIndex := -1
openBlockType := ""
seenText := ""
openToolIndex := -1
openToolID := ""
openToolName := ""
seenToolJSON := ""
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
if !strings.HasPrefix(line, "data:") {
if errors.Is(err, io.EOF) {
break
}
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" || payload == "[DONE]" {
if errors.Is(err, io.EOF) {
break
}
continue
}
geminiResp, err := unwrapGeminiResponse([]byte(payload))
if err != nil {
continue
}
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
finishReason = fr
}
parts := extractGeminiParts(geminiResp)
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
delta, newSeen := computeGeminiTextDelta(seenText, text)
seenText = newSeen
if delta == "" {
continue
}
if openBlockType != "text" {
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
openBlockType = "text"
openBlockIndex = nextBlockIndex
nextBlockIndex++
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": delta,
},
})
flusher.Flush()
continue
}
if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
name, _ := fc["name"].(string)
args := fc["args"]
if strings.TrimSpace(name) == "" {
name = "tool"
}
// Close any open text block before tool_use.
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
openBlockIndex = -1
openBlockType = ""
}
// If we receive streamed tool args in pieces, keep a single tool block open and emit deltas.
if openToolIndex >= 0 && openToolName != name {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
openToolIndex = -1
openToolName = ""
seenToolJSON = ""
}
if openToolIndex < 0 {
openToolID = "toolu_" + randomHex(8)
openToolIndex = nextBlockIndex
openToolName = name
nextBlockIndex++
sawToolUse = true
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openToolIndex,
"content_block": map[string]any{
"type": "tool_use",
"id": openToolID,
"name": name,
"input": map[string]any{},
},
})
}
argsJSONText := "{}"
switch v := args.(type) {
case nil:
// keep default "{}"
case string:
if strings.TrimSpace(v) != "" {
argsJSONText = v
}
default:
if b, err := json.Marshal(args); err == nil && len(b) > 0 {
argsJSONText = string(b)
}
}
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
seenToolJSON = newSeen
if delta != "" {
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openToolIndex,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": delta,
},
})
}
flusher.Flush()
}
}
if u := extractGeminiUsage(geminiResp); u != nil {
usage = *u
}
// Process the final unterminated line at EOF as well.
if errors.Is(err, io.EOF) {
break
}
}
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
if openToolIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
}
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
if sawToolUse {
stopReason = "tool_use"
}
usageObj := map[string]any{
"output_tokens": usage.OutputTokens,
}
if usage.InputTokens > 0 {
usageObj["input_tokens"] = usage.InputTokens
}
writeSSE(c.Writer, "message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usageObj,
})
writeSSE(c.Writer, "message_stop", map[string]any{
"type": "message_stop",
})
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
func writeSSE(w io.Writer, event string, data any) {
if event != "" {
_, _ = fmt.Fprintf(w, "event: %s\n", event)
}
b, _ := json.Marshal(data)
_, _ = fmt.Fprintf(w, "data: %s\n\n", string(b))
}
func randomHex(nBytes int) string {
b := make([]byte, nBytes)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": message},
})
return fmt.Errorf("%s", message)
}
func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
return fmt.Errorf("%s", message)
}
func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
if !isOAuth {
return raw
}
inner, err := unwrapGeminiResponse(raw)
if err != nil {
return raw
}
b, err := json.Marshal(inner)
if err != nil {
return raw
}
return b
}
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
reader := bufio.NewReader(body)
var last map[string]any
var lastWithParts map[string]any
var collectedTextParts []string // Collect all text parts for aggregation
usage := &ClaudeUsage{}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
default:
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
last = parsed
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// Collect text from each part for aggregation
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
}
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, nil, err
}
}
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
if lastWithParts != nil {
return lastWithParts
}
if last != nil {
return last
}
return map[string]any{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
// Join all text parts
mergedText := strings.Join(textParts, "")
// Deep copy response
result := make(map[string]any)
for k, v := range response {
result[k] = v
}
// Get or create candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// Get first candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// Get or create content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// Get existing parts
existingParts, ok := content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// Find and update first text part, or create new one
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// Replace with merged text
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
content["parts"] = newParts
result["candidates"] = candidates
return result
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
if strings.Contains(strings.ToLower(headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
lower := strings.ToLower(string(body))
return strings.Contains(lower, "insufficient authentication scopes") || strings.Contains(lower, "access_token_scope_insufficient")
}
func estimateGeminiCountTokens(reqBody []byte) int {
var obj map[string]any
if err := json.Unmarshal(reqBody, &obj); err != nil {
return 0
}
var texts []string
// systemInstruction.parts[].text
if si, ok := obj["systemInstruction"].(map[string]any); ok {
if parts, ok := si["parts"].([]any); ok {
for _, p := range parts {
if pm, ok := p.(map[string]any); ok {
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
}
// contents[].parts[].text
if contents, ok := obj["contents"].([]any); ok {
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
}
parts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range parts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
total := 0
for _, t := range texts {
total += estimateTokensForText(t)
}
if total < 0 {
return 0
}
return total
}
func estimateTokensForText(s string) int {
s = strings.TrimSpace(s)
if s == "" {
return 0
}
runes := []rune(s)
if len(runes) == 0 {
return 0
}
ascii := 0
for _, r := range runes {
if r <= 0x7f {
ascii++
}
}
asciiRatio := float64(ascii) / float64(len(runes))
if asciiRatio >= 0.8 {
// Roughly 4 chars per token for English-like text.
return (len(runes) + 3) / 4
}
// For CJK-heavy text, approximate 1 rune per token.
return len(runes)
}
type UpstreamHTTPResult struct {
StatusCode int
Headers http.Header
Body []byte
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var parsed map[string]any
if isOAuth {
parsed, err = unwrapGeminiResponse(respBody)
if err == nil && parsed != nil {
respBody, _ = json.Marshal(parsed)
}
} else {
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
return u, nil
}
}
return &ClaudeUsage{}, nil
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
// Keepalive / done markers
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
var rawToWrite string
rawToWrite = payload
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
if b, err := json.Marshal(inner); err == nil {
rawToWrite = string(b)
}
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if isOAuth {
// SSE format requires double newline (\n\n) to separate events
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite)
} else {
// Pass-through for AI Studio responses.
_, _ = io.WriteString(c.Writer, line)
}
flusher.Flush()
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
if account == nil {
return nil, errors.New("account is nil")
}
path = strings.TrimSpace(path)
if path == "" || !strings.HasPrefix(path, "/") {
return nil, errors.New("invalid path")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
return nil, err
}
switch account.Type {
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
}
req.Header.Set("x-goog-api-key", apiKey)
case AccountTypeOAuth:
if s.tokenProvider == nil {
return nil, errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: filteredHeaders,
Body: body,
}, nil
}
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
var outer map[string]any
if err := json.Unmarshal(raw, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"].(map[string]any); ok && resp != nil {
return resp, nil
}
return outer, nil
}
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
usage := extractGeminiUsage(geminiResp)
if usage == nil {
usage = &ClaudeUsage{}
}
contentBlocks := make([]any, 0)
sawToolUse := false
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if content, ok := cand["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
pm, ok := part.(map[string]any)
if !ok {
continue
}
if text, ok := pm["text"].(string); ok && text != "" {
contentBlocks = append(contentBlocks, map[string]any{
"type": "text",
"text": text,
})
}
if fc, ok := pm["functionCall"].(map[string]any); ok {
name, _ := fc["name"].(string)
if strings.TrimSpace(name) == "" {
name = "tool"
}
args := fc["args"]
sawToolUse = true
contentBlocks = append(contentBlocks, map[string]any{
"type": "tool_use",
"id": "toolu_" + randomHex(8),
"name": name,
"input": args,
})
}
}
}
}
}
}
stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp))
if sawToolUse {
stopReason = "tool_use"
}
resp := map[string]any{
"id": "msg_" + randomHex(12),
"type": "message",
"role": "assistant",
"model": originalModel,
"content": contentBlocks,
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": usage.InputTokens,
"output_tokens": usage.OutputTokens,
},
}
return resp, usage
}
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
usageMeta, ok := geminiResp["usageMetadata"].(map[string]any)
if !ok || usageMeta == nil {
return nil
}
prompt, _ := asInt(usageMeta["promptTokenCount"])
cand, _ := asInt(usageMeta["candidatesTokenCount"])
return &ClaudeUsage{
InputTokens: prompt,
OutputTokens: cand,
}
}
func asInt(v any) (int, bool) {
switch t := v.(type) {
case float64:
return int(t), true
case int:
return t, true
case int64:
return int(t), true
case json.Number:
i, err := t.Int64()
if err != nil {
return 0, false
}
return int(i), true
default:
return 0, false
}
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return
}
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
// Try to parse metadata.quotaResetDelay like "12.345s"
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err == nil {
if errObj, ok := parsed["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok {
if looksLikeGeminiDailyQuota(msg) {
if ts := nextGeminiDailyResetUnix(); ts != nil {
return ts
}
}
}
if details, ok := errObj["details"].([]any); ok {
for _, d := range details {
dm, ok := d.(map[string]any)
if !ok {
continue
}
if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil {
ts := time.Now().Unix() + int64(dur.Seconds())
return &ts
}
}
}
}
}
}
}
// Match "Please retry in Xs"
matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts
}
}
return nil
}
func looksLikeGeminiDailyQuota(message string) bool {
m := strings.ToLower(message)
if strings.Contains(m, "per day") || strings.Contains(m, "requests per day") || strings.Contains(m, "quota") && strings.Contains(m, "per day") {
return true
}
return false
}
func nextGeminiDailyResetUnix() *int64 {
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}
func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok {
return fr
}
}
}
return ""
}
func extractGeminiParts(geminiResp map[string]any) []map[string]any {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if content, ok := cand["content"].(map[string]any); ok {
if partsAny, ok := content["parts"].([]any); ok && len(partsAny) > 0 {
out := make([]map[string]any, 0, len(partsAny))
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok {
continue
}
out = append(out, pm)
}
return out
}
}
}
}
return nil
}
func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) {
incoming = strings.TrimSuffix(incoming, "\u0000")
if incoming == "" {
return "", seen
}
// Cumulative mode: incoming contains full text so far.
if strings.HasPrefix(incoming, seen) {
return strings.TrimPrefix(incoming, seen), incoming
}
// Duplicate/rewind: ignore.
if strings.HasPrefix(seen, incoming) {
return "", seen
}
// Delta mode: treat incoming as incremental chunk.
return incoming, seen + incoming
}
func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string {
switch strings.ToUpper(strings.TrimSpace(finishReason)) {
case "MAX_TOKENS":
return "max_tokens"
case "STOP":
return "end_turn"
default:
return "end_turn"
}
}
func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
toolUseIDToName := make(map[string]string)
systemText := extractClaudeSystemText(req["system"])
contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName)
if err != nil {
return nil, err
}
out := make(map[string]any)
if systemText != "" {
out["systemInstruction"] = map[string]any{
"parts": []any{map[string]any{"text": systemText}},
}
}
out["contents"] = contents
if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil {
out["tools"] = tools
}
generationConfig := convertClaudeGenerationConfig(req)
if generationConfig != nil {
out["generationConfig"] = generationConfig
}
stripGeminiFunctionIDs(out)
return json.Marshal(out)
}
func stripGeminiFunctionIDs(req map[string]any) {
// Defensive cleanup: some upstreams reject unexpected `id` fields in functionCall/functionResponse.
contents, ok := req["contents"].([]any)
if !ok {
return
}
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
}
contentParts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range contentParts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); ok && fc != nil {
delete(fc, "id")
}
if fr, ok := pm["functionResponse"].(map[string]any); ok && fr != nil {
delete(fr, "id")
}
}
}
}
func extractClaudeSystemText(system any) string {
switch v := system.(type) {
case string:
return strings.TrimSpace(v)
case []any:
var parts []string
for _, p := range v {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, _ := pm["type"].(string); t != "text" {
continue
}
if text, ok := pm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
return strings.TrimSpace(strings.Join(parts, "\n"))
default:
return ""
}
}
func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) {
arr, ok := messages.([]any)
if !ok {
return nil, errors.New("messages must be an array")
}
out := make([]any, 0, len(arr))
for _, m := range arr {
mm, ok := m.(map[string]any)
if !ok {
continue
}
role, _ := mm["role"].(string)
role = strings.ToLower(strings.TrimSpace(role))
gRole := "user"
if role == "assistant" {
gRole = "model"
}
parts := make([]any, 0)
switch content := mm["content"].(type) {
case string:
// 字符串形式的 content保留所有内容包括空白
parts = append(parts, map[string]any{"text": content})
case []any:
// 如果只有一个 block不过滤空白让上游 API 报错)
singleBlock := len(content) == 1
for _, block := range content {
bm, ok := block.(map[string]any)
if !ok {
continue
}
bt, _ := bm["type"].(string)
switch bt {
case "text":
if text, ok := bm["text"].(string); ok {
// 单个 block 时保留所有内容(包括空白)
// 多个 blocks 时过滤掉空白
if singleBlock || strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
}
case "tool_use":
id, _ := bm["id"].(string)
name, _ := bm["name"].(string)
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name
}
parts = append(parts, map[string]any{
"functionCall": map[string]any{
"name": name,
"args": bm["input"],
},
})
case "tool_result":
toolUseID, _ := bm["tool_use_id"].(string)
name := toolUseIDToName[toolUseID]
if name == "" {
name = "tool"
}
parts = append(parts, map[string]any{
"functionResponse": map[string]any{
"name": name,
"response": map[string]any{
"content": extractClaudeContentText(bm["content"]),
},
},
})
case "image":
if src, ok := bm["source"].(map[string]any); ok {
if srcType, _ := src["type"].(string); srcType == "base64" {
mediaType, _ := src["media_type"].(string)
data, _ := src["data"].(string)
if mediaType != "" && data != "" {
parts = append(parts, map[string]any{
"inlineData": map[string]any{
"mimeType": mediaType,
"data": data,
},
})
}
}
}
default:
// best-effort: preserve unknown blocks as text
if b, err := json.Marshal(bm); err == nil {
parts = append(parts, map[string]any{"text": string(b)})
}
}
}
default:
// ignore
}
out = append(out, map[string]any{
"role": gRole,
"parts": parts,
})
}
return out, nil
}
func extractClaudeContentText(v any) string {
switch t := v.(type) {
case string:
return t
case []any:
var sb strings.Builder
for _, part := range t {
pm, ok := part.(map[string]any)
if !ok {
continue
}
if pm["type"] == "text" {
if text, ok := pm["text"].(string); ok {
_, _ = sb.WriteString(text)
}
}
}
return sb.String()
default:
b, _ := json.Marshal(t)
return string(b)
}
}
func convertClaudeToolsToGeminiTools(tools any) []any {
arr, ok := tools.([]any)
if !ok || len(arr) == 0 {
return nil
}
funcDecls := make([]any, 0, len(arr))
for _, t := range arr {
tm, ok := t.(map[string]any)
if !ok {
continue
}
var name, desc string
var params any
// 检查是否为 custom 类型工具 (MCP)
toolType, _ := tm["type"].(string)
if toolType == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema
custom, ok := tm["custom"].(map[string]any)
if !ok {
continue
}
name, _ = tm["name"].(string)
desc, _ = custom["description"].(string)
params = custom["input_schema"]
} else {
// 标准格式: 从顶层字段获取
name, _ = tm["name"].(string)
desc, _ = tm["description"].(string)
params = tm["input_schema"]
}
if name == "" {
continue
}
// 为 nil params 提供默认值
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{
"name": name,
"description": desc,
"parameters": cleanedParams,
})
}
if len(funcDecls) == 0 {
return nil
}
return []any{
map[string]any{
"functionDeclarations": funcDecls,
},
}
}
// cleanToolSchema 清理工具的 JSON Schema移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
out["maxOutputTokens"] = mt
}
if temp, ok := req["temperature"].(float64); ok {
out["temperature"] = temp
}
if topP, ok := req["top_p"].(float64); ok {
out["topP"] = topP
}
if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 {
out["stopSequences"] = stopSeq
}
if len(out) == 0 {
return nil
}
return out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
ImageSize string `json:"imageSize"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
return "2K"
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K"
}