Files
sub2api/backend/internal/service/gemini_messages_compat_service.go
2026-01-05 14:43:08 +08:00

2550 lines
71 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
mathrand "math/rand"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
// geminiStickySessionTTL is the TTL applied to the session -> account binding
// that is stored (and refreshed) in the gateway cache.
const geminiStickySessionTTL = time.Hour
// Retry tuning for upstream Gemini requests.
const (
// geminiMaxRetries bounds the number of attempts made by the upstream retry loops.
geminiMaxRetries = 5
// geminiRetryBaseDelay is the backoff delay (before jitter) for the first retry.
geminiRetryBaseDelay = 1 * time.Second
// geminiRetryMaxDelay caps the exponentially growing backoff delay (before jitter).
geminiRetryMaxDelay = 16 * time.Second
)
// GeminiMessagesCompatService selects Gemini accounts and forwards both
// Claude-compatible and native Gemini requests to the upstream API.
type GeminiMessagesCompatService struct {
// accountRepo lists and loads schedulable accounts.
accountRepo AccountRepository
// groupRepo resolves a group ID to its group (used to read the group's platform).
groupRepo GroupRepository
// cache stores the sticky session -> account binding.
cache GatewayCache
// tokenProvider supplies access tokens for OAuth accounts; nil-checked before use.
tokenProvider *GeminiTokenProvider
// rateLimitService is optional (nil-checked at call sites); it pre-checks usage
// and handles temporarily unschedulable accounts.
rateLimitService *RateLimitService
// httpUpstream performs the outbound HTTP requests.
httpUpstream HTTPUpstream
// antigravityGatewayService is exposed through GetAntigravityGatewayService.
antigravityGatewayService *AntigravityGatewayService
// cfg provides the URL allowlist settings used to validate upstream base URLs.
cfg *config.Config
}
// NewGeminiMessagesCompatService builds a GeminiMessagesCompatService from its
// collaborators. Every argument is stored as-is; no validation is performed here.
func NewGeminiMessagesCompatService(
	accountRepo AccountRepository,
	groupRepo GroupRepository,
	cache GatewayCache,
	tokenProvider *GeminiTokenProvider,
	rateLimitService *RateLimitService,
	httpUpstream HTTPUpstream,
	antigravityGatewayService *AntigravityGatewayService,
	cfg *config.Config,
) *GeminiMessagesCompatService {
	svc := new(GeminiMessagesCompatService)

	// Repositories and cache.
	svc.accountRepo = accountRepo
	svc.groupRepo = groupRepo
	svc.cache = cache

	// Upstream access.
	svc.tokenProvider = tokenProvider
	svc.rateLimitService = rateLimitService
	svc.httpUpstream = httpUpstream
	svc.antigravityGatewayService = antigravityGatewayService

	// Configuration.
	svc.cfg = cfg

	return svc
}
// GetTokenProvider returns the token provider this service uses to obtain
// access tokens for OAuth accounts. The result is whatever was passed to the
// constructor and may be nil.
func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
// SelectAccountForModel picks an account for requestedModel without excluding
// any account. It delegates to SelectAccountForModelWithExclusions with a nil
// exclusion set.
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions picks a schedulable account for requestedModel,
// skipping every account whose ID is a key of excludedIDs.
//
// Selection steps:
//  1. Resolve the platform: forced platform from ctx, else the group's platform,
//     else the native gemini platform.
//  2. If sessionHash is non-empty and the cache holds a still-usable account for it,
//     refresh the binding's TTL and return that account (sticky session).
//  3. Otherwise list schedulable accounts and choose the one with the lowest Priority;
//     ties prefer never-used accounts, then OAuth accounts (when both are unused),
//     then the least recently used account.
//
// When sessionHash is non-empty the chosen account is stored in the cache.
// An error is returned when the group or account query fails or no account qualifies.
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// Prefer the forced platform carried in the context (set by the /antigravity route).
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// Use the group's platform to decide which kind of account to query.
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// Without a group, only the native gemini platform is used.
platform = PlatformGemini
}
// Gemini groups support mixed scheduling (which also includes antigravity accounts
// that have mixed_scheduling enabled).
// Note: forced-platform mode never uses mixed scheduling.
// NOTE(review): hasForcePlatform is true even when the context value is an empty
// string, whereas the platform override above also requires a non-empty value —
// confirm that an empty forced platform is never stored in the context.
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
var queryPlatforms []string
if useMixedScheduling {
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
} else {
queryPlatforms = []string{platform}
}
// Sticky session: reuse the account previously bound to this session when it is still usable.
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// Check that the account is still valid: native-platform accounts match directly,
// antigravity accounts additionally require mixed scheduling to be enabled.
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
// A precheck error is only logged; the ok result still decides usability.
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
// Best effort: a failed TTL refresh does not prevent using the account.
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
}
// Query schedulable accounts (forced-platform mode: look within the group first,
// then fall back to all accounts).
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// In forced-platform mode, fall back to querying all accounts when the group has none.
if len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// In mixed-scheduling mode native-platform accounts pass directly, while
// antigravity accounts need mixed_scheduling enabled.
// In non-mixed mode (antigravity group) no filtering is needed.
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
}
// Lower Priority value wins; equal priorities fall through to the tie-breaks below.
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
// Both have been used: prefer the least recently used account.
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
if sessionHash != "" {
// Best effort: failing to store the sticky binding does not fail the selection.
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil
}
// isModelSupportedByAccount reports whether account can serve requestedModel.
// Antigravity accounts are checked against the platform-wide antigravity model
// list; every other account is asked directly.
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
	switch account.Platform {
	case PlatformAntigravity:
		return IsAntigravityModelSupported(requestedModel)
	default:
		return account.IsModelSupported(requestedModel)
	}
}
// GetAntigravityGatewayService returns the AntigravityGatewayService that was
// passed to the constructor.
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
// validateUpstreamBaseURL validates raw as an upstream base URL and returns its
// normalized form.
//
// When the URL allowlist is disabled only the URL format is validated (plain HTTP
// is accepted only if AllowInsecureHTTP is set). When the allowlist is enabled the
// URL must pass HTTPS validation against the configured upstream hosts.
//
// An error is returned when the service has no configuration: previously a nil
// cfg skipped the first branch and then panicked on the nil dereference below.
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg == nil {
		// Without a config there is no allowlist to validate against; fail closed.
		return "", errors.New("invalid base_url: security config is not available")
	}

	if !s.cfg.Security.URLAllowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid base_url: %w", err)
		}
		return normalized, nil
	}

	normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
		AllowedHosts:     s.cfg.Security.URLAllowlist.UpstreamHosts,
		RequireAllowlist: true,
		AllowPrivate:     s.cfg.Security.URLAllowlist.AllowPrivateHosts,
	})
	if err != nil {
		return "", fmt.Errorf("invalid base_url: %w", err)
	}
	return normalized, nil
}
// HasAntigravityAccounts reports whether at least one schedulable antigravity
// account exists. With a non-nil groupID the lookup is restricted to that group;
// otherwise all antigravity accounts are considered. Repository errors are
// returned unchanged.
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
	var (
		found []Account
		err   error
	)
	switch {
	case groupID == nil:
		found, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
	default:
		found, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
	}
	if err != nil {
		return false, err
	}
	return len(found) != 0, nil
}
// SelectAccountForAIStudioEndpoints picks the Gemini account most likely to work
// against generativelanguage.googleapis.com (for example GET /v1beta/models).
//
// Accounts are grouped into tiers, lowest tier wins:
//
//	0  API key accounts with a non-empty api_key (AI Studio)
//	1  OAuth accounts without project_id (AI Studio OAuth)
//	2  OAuth accounts explicitly marked oauth_type=ai_studio
//	3  remaining OAuth accounts (Code Assist tokens often lack AI Studio scopes)
//	9  API key accounts whose api_key is empty
//	10 any other account type
//
// Within a tier the lower Priority wins; equal priorities prefer a never-used
// account, then an OAuth account (when both are unused), then the least
// recently used one. With a non-nil groupID only that group's accounts are considered.
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
	var (
		candidates []Account
		err        error
	)
	if groupID == nil {
		candidates, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
	} else {
		candidates, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
	}
	if err != nil {
		return nil, fmt.Errorf("query accounts failed: %w", err)
	}
	if len(candidates) == 0 {
		return nil, errors.New("no available Gemini accounts")
	}

	// credential returns the whitespace-trimmed credential stored under key.
	credential := func(a *Account, key string) string {
		return strings.TrimSpace(a.GetCredential(key))
	}

	// tier maps an account to its preference tier (see the function comment).
	tier := func(a *Account) int {
		switch a.Type {
		case AccountTypeAPIKey:
			if credential(a, "api_key") != "" {
				return 0
			}
			return 9
		case AccountTypeOAuth:
			if credential(a, "project_id") == "" {
				return 1
			}
			if credential(a, "oauth_type") == "ai_studio" {
				return 2
			}
			// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
			return 3
		}
		return 10
	}

	// outranks decides between two accounts of the same tier.
	outranks := func(cand, cur *Account) bool {
		if cand.Priority != cur.Priority {
			return cand.Priority < cur.Priority
		}
		candUsed, curUsed := cand.LastUsedAt != nil, cur.LastUsedAt != nil
		switch {
		case !candUsed && curUsed:
			return true
		case candUsed && !curUsed:
			return false
		case !candUsed && !curUsed:
			return cand.Type == AccountTypeOAuth && cur.Type != AccountTypeOAuth
		default:
			return cand.LastUsedAt.Before(*cur.LastUsedAt)
		}
	}

	// The slice is non-empty here, so the first element seeds the search.
	best := &candidates[0]
	for i := 1; i < len(candidates); i++ {
		cand := &candidates[i]
		candTier, bestTier := tier(cand), tier(best)
		if candTier > bestTier {
			continue
		}
		if candTier < bestTier || outranks(cand, best) {
			best = cand
		}
	}
	return best, nil
}
// Forward translates a Claude Messages request into a Gemini generateContent call,
// sends it upstream with the given account, and writes the Claude-formatted result
// (streaming or not) to c.
//
// The upstream endpoint depends on the account:
//   - API key accounts call the AI Studio API with the x-goog-api-key header.
//   - OAuth accounts with a project_id call the Code Assist API (wrapped request).
//   - OAuth accounts without a project_id call the AI Studio API with a Bearer token.
//
// Retryable upstream failures are retried up to geminiMaxRetries times with backoff.
// A 400 that looks signature-related is retried at most twice with progressively
// downgraded Claude blocks, and only while another attempt remains. When the final
// upstream status is >= 400, either an *UpstreamFailoverError is returned (so the
// caller can try another account) or a mapped Claude error is written to c.
//
// On success the returned ForwardResult always carries a usage value (zero-valued
// when the response handlers reported none).
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
	startTime := time.Now()

	var req struct {
		Model  string `json:"model"`
		Stream bool   `json:"stream"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	if strings.TrimSpace(req.Model) == "" {
		return nil, fmt.Errorf("missing model")
	}

	originalModel := req.Model
	mappedModel := req.Model
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(req.Model)
	}

	geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body)
	if err != nil {
		return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
	}
	// Kept so signature retries always downgrade from the untouched client payload.
	originalClaudeBody := body

	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	var requestIDHeader string
	var buildReq func(ctx context.Context) (*http.Request, string, error)
	useUpstreamStream := req.Stream
	if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
		// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
		useUpstreamStream = true
	}

	// buildReq is invoked once per attempt so that each retry gets a fresh request
	// (and picks up a replaced geminiReq after a signature downgrade).
	switch account.Type {
	case AccountTypeAPIKey:
		buildReq = func(ctx context.Context) (*http.Request, string, error) {
			apiKey := account.GetCredential("api_key")
			if strings.TrimSpace(apiKey) == "" {
				return nil, "", errors.New("gemini api_key not configured")
			}
			baseURL := strings.TrimSpace(account.GetCredential("base_url"))
			if baseURL == "" {
				baseURL = geminicli.AIStudioBaseURL
			}
			normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, "", err
			}
			action := "generateContent"
			if req.Stream {
				action = "streamGenerateContent"
			}
			fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
			if req.Stream {
				fullURL += "?alt=sse"
			}
			upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
			if err != nil {
				return nil, "", err
			}
			upstreamReq.Header.Set("Content-Type", "application/json")
			upstreamReq.Header.Set("x-goog-api-key", apiKey)
			return upstreamReq, "x-request-id", nil
		}
		requestIDHeader = "x-request-id"
	case AccountTypeOAuth:
		buildReq = func(ctx context.Context) (*http.Request, string, error) {
			if s.tokenProvider == nil {
				return nil, "", errors.New("gemini token provider not configured")
			}
			accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
			if err != nil {
				return nil, "", err
			}
			projectID := strings.TrimSpace(account.GetCredential("project_id"))
			action := "generateContent"
			if useUpstreamStream {
				action = "streamGenerateContent"
			}

			// Two modes for OAuth:
			// 1. With project_id -> Code Assist API (wrapped request)
			// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
			if projectID != "" {
				// Mode 1: Code Assist API
				baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
				if err != nil {
					return nil, "", err
				}
				fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
				if useUpstreamStream {
					fullURL += "?alt=sse"
				}
				wrapped := map[string]any{
					"model":   mappedModel,
					"project": projectID,
				}
				var inner any
				if err := json.Unmarshal(geminiReq, &inner); err != nil {
					return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
				}
				wrapped["request"] = inner
				wrappedBytes, err := json.Marshal(wrapped)
				if err != nil {
					return nil, "", fmt.Errorf("failed to encode gemini request: %w", err)
				}
				upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
				if err != nil {
					return nil, "", err
				}
				upstreamReq.Header.Set("Content-Type", "application/json")
				upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
				upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
				return upstreamReq, "x-request-id", nil
			}

			// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
			baseURL := strings.TrimSpace(account.GetCredential("base_url"))
			if baseURL == "" {
				baseURL = geminicli.AIStudioBaseURL
			}
			normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, "", err
			}
			fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
			if useUpstreamStream {
				fullURL += "?alt=sse"
			}
			upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
			if err != nil {
				return nil, "", err
			}
			upstreamReq.Header.Set("Content-Type", "application/json")
			upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
			return upstreamReq, "x-request-id", nil
		}
		requestIDHeader = "x-request-id"
	default:
		return nil, fmt.Errorf("unsupported account type: %s", account.Type)
	}

	var resp *http.Response
	signatureRetryStage := 0
	for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
		upstreamReq, idHeader, err := buildReq(ctx)
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return nil, err
			}
			// Local build error: don't retry.
			if strings.Contains(err.Error(), "missing project_id") {
				return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
			}
			return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error())
		}
		requestIDHeader = idHeader

		resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
		if err != nil {
			if attempt < geminiMaxRetries {
				log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
				sleepGeminiBackoff(attempt)
				continue
			}
			return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
		}

		// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
		// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
		if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()

			// Only downgrade-and-retry while another attempt remains. Continuing on the
			// final attempt would leave the loop holding a response whose body is
			// already consumed and closed, losing the upstream error message.
			if attempt < geminiMaxRetries && isGeminiSignatureRelatedError(respBody) {
				var strippedClaudeBody []byte
				stageName := ""
				switch signatureRetryStage {
				case 0:
					// Stage 1: disable thinking + thinking->text
					strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
					stageName = "thinking-only"
					signatureRetryStage = 1
				default:
					// Stage 2: additionally downgrade tool_use/tool_result blocks to text
					strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
					stageName = "thinking+tools"
					signatureRetryStage = 2
				}
				retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
				if txErr == nil {
					log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
					geminiReq = retryGeminiReq
					// Consume one retry budget attempt and continue with the updated request payload.
					sleepGeminiBackoff(1)
					continue
				}
			}

			// Restore body for downstream error handling.
			resp = &http.Response{
				StatusCode: http.StatusBadRequest,
				Header:     resp.Header.Clone(),
				Body:       io.NopCloser(bytes.NewReader(respBody)),
			}
			break
		}

		if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			// Don't treat insufficient-scope as transient.
			if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
				resp = &http.Response{
					StatusCode: resp.StatusCode,
					Header:     resp.Header.Clone(),
					Body:       io.NopCloser(bytes.NewReader(respBody)),
				}
				break
			}
			if resp.StatusCode == 429 {
				// Mark as rate-limited early so concurrent requests avoid this account.
				s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			if attempt < geminiMaxRetries {
				log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
				sleepGeminiBackoff(attempt)
				continue
			}
			// Final attempt: surface the upstream error body (mapped below) instead of a generic retry error.
			resp = &http.Response{
				StatusCode: resp.StatusCode,
				Header:     resp.Header.Clone(),
				Body:       io.NopCloser(bytes.NewReader(respBody)),
			}
			break
		}

		break
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode >= 400 {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		tempMatched := false
		if s.rateLimitService != nil {
			tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
		}
		s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
		if tempMatched {
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
	}

	requestID := resp.Header.Get(requestIDHeader)
	if requestID == "" {
		requestID = resp.Header.Get("x-goog-request-id")
	}
	if requestID != "" {
		c.Header("x-request-id", requestID)
	}

	var usage *ClaudeUsage
	var firstTokenMs *int
	if req.Stream {
		streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
		if err != nil {
			return nil, err
		}
		usage = streamRes.usage
		firstTokenMs = streamRes.firstTokenMs
	} else {
		if useUpstreamStream {
			// The client asked for a non-streaming reply but upstream was streamed:
			// aggregate the SSE events and answer with a single JSON message.
			collected, usageObj, err := collectGeminiSSE(resp.Body, true)
			if err != nil {
				return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
			}
			claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
			c.JSON(http.StatusOK, claudeResp)
			usage = usageObj2
			if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
				usage = usageObj
			}
		} else {
			usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
			if err != nil {
				return nil, err
			}
		}
	}

	// Same guard as ForwardNative: never dereference a missing usage report.
	if usage == nil {
		usage = &ClaudeUsage{}
	}

	return &ForwardResult{
		RequestID:    requestID,
		Usage:        *usage,
		Model:        originalModel,
		Stream:       req.Stream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
// isGeminiSignatureRelatedError reports whether an upstream error body mentions a
// signature problem. The extracted error message is inspected first; when none can
// be extracted the raw body is searched instead. Matching is case-insensitive.
func isGeminiSignatureRelatedError(respBody []byte) bool {
	haystack := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
	if len(haystack) == 0 {
		haystack = strings.ToLower(string(respBody))
	}
	for _, needle := range []string{"thought_signature", "signature"} {
		if strings.Contains(haystack, needle) {
			return true
		}
	}
	return false
}
// ForwardNative forwards a native Gemini request for originalModel to the upstream
// API using the given account and writes the upstream response to c.
//
// Supported actions are generateContent, streamGenerateContent and countTokens;
// any other action is answered with a 404 Google-style error. The request body is
// passed through as-is, except that it is wrapped for the Code Assist API when the
// OAuth account has a project_id and the action is not countTokens.
//
// Retryable upstream failures are retried up to geminiMaxRetries times with backoff.
// For countTokens, exhausted retries and OAuth insufficient-scope errors are answered
// with a locally estimated token count instead of an error. Other upstream errors
// either yield an *UpstreamFailoverError or are passed through to the client.
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
// Model mapping is only applied for API key accounts.
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
useUpstreamStream := stream
upstreamAction := action
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
}
// countTokens is always sent to the AI Studio endpoint, even for OAuth accounts with a project_id.
forceAIStudio := action == "countTokens"
var requestIDHeader string
// buildReq is invoked once per attempt so that every retry gets a fresh request.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
// The Code Assist API expects the original request nested under "request",
// next to the model and project.
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(body, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
var resp *http.Response
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error())
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error())
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
// Retries exhausted: countTokens degrades to a local estimate instead of failing.
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
// The body is read and closed here; whenever the loop is left with this
// response, it is rebuilt below with a fresh reader over respBody.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// Don't treat insufficient-scope as transient.
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode == 429 {
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
// Retries exhausted: countTokens degrades to a local estimate instead of failing.
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
// Prefer the request ID header named by buildReq; fall back to x-goog-request-id.
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
// Failover errors are returned without writing a response to c.
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Non-failover errors are passed through to the client with the upstream status.
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream {
streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
// The client asked for a non-streaming reply but upstream was streamed:
// aggregate the SSE events and answer with a single JSON document.
collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream")
}
b, _ := json.Marshal(collected)
c.Data(http.StatusOK, "application/json", b)
usage = usageObj
} else {
usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth)
if err != nil {
return nil, err
}
usage = usageResp
}
}
// Guard the dereference below when no usage was reported.
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// shouldRetryGeminiUpstreamError reports whether an upstream response with
// statusCode should be retried against the same account.
//
// 429, 500, 502, 503, 504 and 529 are always retried. 403 is retried only for
// Code Assist OAuth accounts: oauth_type is "code_assist" (case-insensitive), or
// oauth_type is empty and a project_id is present. Everything else is not retried.
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
	if statusCode != 403 {
		switch statusCode {
		case 429, 500, 502, 503, 504, 529:
			return true
		}
		return false
	}

	// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
	if account == nil || account.Type != AccountTypeOAuth {
		return false
	}
	kind := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
	if kind != "" {
		return kind == "code_assist"
	}
	// Legacy/implicit Code Assist OAuth accounts: no oauth_type but a project_id.
	return strings.TrimSpace(account.GetCredential("project_id")) != ""
}
// shouldFailoverGeminiUpstreamError reports whether an upstream status should
// make the caller switch to another account: 401, 403, 429 and every status of
// 500 or above (which includes 529).
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
	if statusCode >= 500 {
		return true
	}
	return statusCode == 401 || statusCode == 403 || statusCode == 429
}
// sleepGeminiBackoff blocks for the backoff delay of the given 1-based attempt:
// geminiRetryBaseDelay doubled for every attempt after the first, capped at
// geminiRetryMaxDelay, with +/- 20% random jitter applied.
//
// The delay is built by capped doubling instead of a bit shift, so a large
// attempt can no longer overflow the duration, and an attempt below 1 is treated
// as the first attempt instead of collapsing to a zero delay. Jitter uses the
// shared math/rand generator rather than allocating and seeding a new one per call.
func sleepGeminiBackoff(attempt int) {
	delay := geminiRetryBaseDelay
	for i := 1; i < attempt && delay < geminiRetryMaxDelay; i++ {
		delay *= 2
	}
	if delay > geminiRetryMaxDelay {
		delay = geminiRetryMaxDelay
	}

	// +/- 20% jitter
	jitter := time.Duration(float64(delay) * 0.2 * (mathrand.Float64()*2 - 1))
	if sleepFor := delay + jitter; sleepFor > 0 {
		time.Sleep(sleepFor)
	}
}
var (
	// sensitiveQueryParamRegex matches the value of credential-bearing query
	// parameters (key, client_secret, access_token, refresh_token), case-insensitively.
	// Group 1 captures the separator and parameter name so it can be kept.
	sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
	// retryInRegex captures the number of seconds from a "Please retry in <n>s" message.
	retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)

// sanitizeUpstreamErrorMessage masks the values of sensitive query parameters in
// msg with "***" so the message can be returned to clients or logged. The
// parameter names and the rest of the message are left untouched.
func sanitizeUpstreamErrorMessage(msg string) string {
	if len(msg) > 0 {
		return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
	}
	return msg
}
// writeGeminiMappedError writes a Claude-style error JSON response for a failed
// upstream Gemini call and returns an error carrying the upstream status.
//
// Values derived from the upstream error body (via mapGeminiErrorBodyToClaudeError)
// take precedence; anything the body does not provide falls back to a default
// selected by upstreamStatus.
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
	statusCode, errType, errMsg := 0, "", ""
	if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
		errType = mapped.Type
		errMsg = mapped.Message
		if mapped.StatusCode > 0 {
			statusCode = mapped.StatusCode
		}
	}

	// Defaults keyed by the upstream HTTP status; the initial values cover
	// every status that has no dedicated case below.
	fallbackStatus := http.StatusBadGateway
	fallbackType := "upstream_error"
	fallbackMsg := "Upstream request failed"
	switch upstreamStatus {
	case 400:
		fallbackStatus = http.StatusBadRequest
		fallbackType = "invalid_request_error"
		fallbackMsg = "Invalid request"
	case 401:
		fallbackType = "authentication_error"
		fallbackMsg = "Upstream authentication failed, please contact administrator"
	case 403:
		fallbackType = "permission_error"
		fallbackMsg = "Upstream access forbidden, please contact administrator"
	case 404:
		fallbackStatus = http.StatusNotFound
		fallbackType = "not_found_error"
		fallbackMsg = "Resource not found"
	case 429:
		fallbackStatus = http.StatusTooManyRequests
		fallbackType = "rate_limit_error"
		fallbackMsg = "Upstream rate limit exceeded, please retry later"
	case 529:
		fallbackStatus = http.StatusServiceUnavailable
		fallbackType = "overloaded_error"
		fallbackMsg = "Upstream service overloaded, please retry later"
	case 500, 502:
		fallbackType = "api_error"
		fallbackMsg = "Upstream service temporarily unavailable"
	case 503:
		fallbackType = "overloaded_error"
		fallbackMsg = "Upstream service temporarily unavailable"
	case 504:
		fallbackType = "timeout_error"
		fallbackMsg = "Upstream service temporarily unavailable"
	}

	if statusCode == 0 {
		statusCode = fallbackStatus
	}
	if errType == "" {
		errType = fallbackType
	}
	if errMsg == "" {
		errMsg = fallbackMsg
	}

	c.JSON(statusCode, gin.H{
		"type":  "error",
		"error": gin.H{"type": errType, "message": errMsg},
	})
	return fmt.Errorf("upstream error: %d", upstreamStatus)
}
// claudeErrorMapping is the result of translating a Gemini error body into
// Claude-style error fields (produced by mapGeminiErrorBodyToClaudeError).
type claudeErrorMapping struct {
// Type is the Claude error type, e.g. "invalid_request_error".
Type string
// Message is the client-facing message; the mapper currently leaves it empty
// so callers fall back to their own generic text.
Message string
// StatusCode is the HTTP status to answer with; 0 means "not decided here".
StatusCode int
}
// mapGeminiErrorBodyToClaudeError parses a Google-style error body
// ({"error": {"code", "message", "status"}}) and maps it to Claude error fields.
//
// It returns nil when body is empty, is not valid JSON, or carries no code,
// message or status. The returned Type defaults to "upstream_error" when the
// status has no dedicated mapping. StatusCode is set only for INVALID_ARGUMENT,
// NOT_FOUND and RESOURCE_EXHAUSTED; otherwise it stays 0 so the caller's HTTP
// status mapping decides. Message is always left empty: upstream messages can
// be long or include sensitive fragments.
func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping {
	if len(body) == 0 {
		return nil
	}

	var envelope struct {
		Error struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
			Status  string `json:"status"`
		} `json:"error"`
	}
	if json.Unmarshal(body, &envelope) != nil {
		return nil
	}

	status := strings.TrimSpace(envelope.Error.Status)
	hasStatus := status != ""
	hasCode := envelope.Error.Code != 0
	hasMessage := strings.TrimSpace(envelope.Error.Message) != ""
	if !hasStatus && !hasCode && !hasMessage {
		return nil
	}

	result := &claudeErrorMapping{Type: "upstream_error"}
	if mappedType := mapGeminiStatusToClaudeErrorType(envelope.Error.Status); mappedType != "" {
		result.Type = mappedType
	}

	switch strings.ToUpper(status) {
	case "INVALID_ARGUMENT":
		result.StatusCode = http.StatusBadRequest
	case "NOT_FOUND":
		result.StatusCode = http.StatusNotFound
	case "RESOURCE_EXHAUSTED":
		result.StatusCode = http.StatusTooManyRequests
	}
	return result
}
// mapGeminiStatusToClaudeErrorType converts a Google API status string (such as
// "INVALID_ARGUMENT") into the corresponding Claude error type. Matching ignores
// case and surrounding whitespace; unrecognized statuses yield "".
func mapGeminiStatusToClaudeErrorType(status string) string {
	errType := ""
	switch normalized := strings.ToUpper(strings.TrimSpace(status)); normalized {
	case "INVALID_ARGUMENT":
		errType = "invalid_request_error"
	case "PERMISSION_DENIED":
		errType = "permission_error"
	case "NOT_FOUND":
		errType = "not_found_error"
	case "RESOURCE_EXHAUSTED":
		errType = "rate_limit_error"
	case "UNAUTHENTICATED":
		errType = "authentication_error"
	case "UNAVAILABLE":
		errType = "overloaded_error"
	case "INTERNAL":
		errType = "api_error"
	case "DEADLINE_EXCEEDED":
		errType = "timeout_error"
	}
	return errType
}
// geminiStreamResult is what handleStreamingResponse reports after a stream ends.
type geminiStreamResult struct {
// usage holds the token counts from the last chunk that carried usage metadata.
usage *ClaudeUsage
// firstTokenMs is the elapsed time in milliseconds until the first text delta
// was emitted; it is nil when no text delta was produced.
firstTokenMs *int
}
// handleNonStreamingResponse reads a complete Gemini response (at most 8 MiB),
// converts it into a Claude message and writes it to the client with HTTP 200.
//
// It returns the usage extracted from the response. If the body cannot be read
// or parsed, a 502 Claude-style "upstream_error" is written to the client and
// the matching error is returned.
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
	const maxBodyBytes = 8 << 20

	raw, readErr := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
	if readErr != nil {
		return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
	}

	parsed, parseErr := unwrapGeminiResponse(raw)
	if parseErr != nil {
		return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
	}

	message, usage := convertGeminiToClaudeMessage(parsed, originalModel)
	c.JSON(http.StatusOK, message)
	return usage, nil
}
// handleStreamingResponse converts an upstream Gemini SSE stream into
// Claude-style SSE events written to the client: message_start, then
// content_block_start / content_block_delta / content_block_stop for text and
// tool_use blocks, and finally message_delta and message_stop.
//
// startTime is the reference point for the time-to-first-token measurement and
// originalModel is echoed back in message_start. It returns the last usage
// seen plus the first-token latency. An error is returned when the response
// writer cannot flush, or when reading the upstream body fails with anything
// other than io.EOF (in that case no closing events are written).
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
messageID := "msg_" + randomHex(12)
// Usage is not known yet, so message_start reports zero token counts; the
// real numbers are sent in the closing message_delta.
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": messageID,
"type": "message",
"role": "assistant",
"model": originalModel,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
writeSSE(c.Writer, "message_start", messageStart)
flusher.Flush()
var firstTokenMs *int
var usage ClaudeUsage
finishReason := ""
sawToolUse := false
// Block bookkeeping: nextBlockIndex is the next content block index to hand
// out; openBlockIndex/openBlockType track the currently open text block and
// openTool* the currently open tool_use block (-1 means none is open).
// seenText and seenToolJSON accumulate what was already forwarded so that
// computeGeminiTextDelta can emit only the new portion.
nextBlockIndex := 0
openBlockIndex := -1
openBlockType := ""
seenText := ""
openToolIndex := -1
openToolID := ""
openToolName := ""
seenToolJSON := ""
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
// Only "data:" lines carry payloads; everything else is skipped.
if !strings.HasPrefix(line, "data:") {
if errors.Is(err, io.EOF) {
break
}
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" || payload == "[DONE]" {
if errors.Is(err, io.EOF) {
break
}
continue
}
// NOTE(review): this := reuses the loop-scoped err, so from here on err holds
// the unwrap result rather than the read result.
geminiResp, err := unwrapGeminiResponse([]byte(payload))
if err != nil {
// Unparseable chunks are dropped silently.
continue
}
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
finishReason = fr
}
parts := extractGeminiParts(geminiResp)
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
delta, newSeen := computeGeminiTextDelta(seenText, text)
seenText = newSeen
if delta == "" {
continue
}
// Open a text block on demand.
// NOTE(review): an open tool_use block (openToolIndex) is not closed here
// when text follows a function call — confirm this is intended.
if openBlockType != "text" {
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
openBlockType = "text"
openBlockIndex = nextBlockIndex
nextBlockIndex++
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
}
// First-token latency is recorded on the first emitted text delta only.
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": delta,
},
})
flusher.Flush()
continue
}
if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
name, _ := fc["name"].(string)
args := fc["args"]
// Unnamed function calls are reported under the placeholder name "tool".
if strings.TrimSpace(name) == "" {
name = "tool"
}
// Close any open text block before tool_use.
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
openBlockIndex = -1
openBlockType = ""
}
// If we receive streamed tool args in pieces, keep a single tool block open and emit deltas.
if openToolIndex >= 0 && openToolName != name {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
openToolIndex = -1
openToolName = ""
seenToolJSON = ""
}
if openToolIndex < 0 {
openToolID = "toolu_" + randomHex(8)
openToolIndex = nextBlockIndex
openToolName = name
nextBlockIndex++
sawToolUse = true
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openToolIndex,
"content_block": map[string]any{
"type": "tool_use",
"id": openToolID,
"name": name,
"input": map[string]any{},
},
})
}
// Normalize args to JSON text: nil or a blank string becomes "{}", a
// non-blank string is used as-is, anything else is marshaled.
argsJSONText := "{}"
switch v := args.(type) {
case nil:
// keep default "{}"
case string:
if strings.TrimSpace(v) != "" {
argsJSONText = v
}
default:
if b, err := json.Marshal(args); err == nil && len(b) > 0 {
argsJSONText = string(b)
}
}
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
seenToolJSON = newSeen
if delta != "" {
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openToolIndex,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": delta,
},
})
}
flusher.Flush()
}
}
// The most recent usage metadata wins.
if u := extractGeminiUsage(geminiResp); u != nil {
usage = *u
}
// Process the final unterminated line at EOF as well.
// NOTE(review): err was overwritten by the unwrap call above and is nil here,
// so this check never fires; the loop ends on the next ReadString, which
// returns an empty line together with io.EOF.
if errors.Is(err, io.EOF) {
break
}
}
// Close whatever is still open: the text block first, then the tool block.
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
if openToolIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
}
// Any tool_use block overrides the stop reason derived from finishReason.
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
if sawToolUse {
stopReason = "tool_use"
}
// input_tokens is included only when the upstream reported a positive count.
usageObj := map[string]any{
"output_tokens": usage.OutputTokens,
}
if usage.InputTokens > 0 {
usageObj["input_tokens"] = usage.InputTokens
}
writeSSE(c.Writer, "message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usageObj,
})
writeSSE(c.Writer, "message_stop", map[string]any{
"type": "message_stop",
})
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// writeSSE writes one server-sent event to w: an optional "event: <name>" line
// (omitted when event is empty) followed by a "data: <json>" line and a blank
// line. data is JSON-encoded; encoding and write errors are ignored.
func writeSSE(w io.Writer, event string, data any) {
	payload, _ := json.Marshal(data)
	if event != "" {
		_, _ = fmt.Fprintf(w, "event: %s\n", event)
	}
	_, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
}
// randomHex returns nBytes bytes read from crypto/rand, hex-encoded, so the
// result is 2*nBytes characters long. A failure of the random source is ignored.
func randomHex(nBytes int) string {
	raw := make([]byte, nBytes)
	_, _ = rand.Read(raw)
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded)
}
// writeClaudeError writes a Claude-style error envelope
// ({"type":"error","error":{"type","message"}}) with the given HTTP status and
// returns an error whose text is message.
func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
	detail := gin.H{"type": errType, "message": message}
	envelope := gin.H{"type": "error", "error": detail}
	c.JSON(status, envelope)
	return fmt.Errorf("%s", message)
}
// writeGoogleError writes a Google-API-style error envelope
// ({"error":{"code","message","status"}}) with the given HTTP status, deriving
// the status string via googleapi.HTTPStatusToGoogleStatus, and returns an
// error whose text is message.
func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error {
	detail := gin.H{
		"code":    status,
		"message": message,
		"status":  googleapi.HTTPStatusToGoogleStatus(status),
	}
	c.JSON(status, gin.H{"error": detail})
	return fmt.Errorf("%s", message)
}
// unwrapIfNeeded returns raw untouched for non-OAuth responses. For OAuth
// responses it unwraps the payload with unwrapGeminiResponse and returns the
// re-encoded inner object; if unwrapping or encoding fails, raw is returned.
func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
	if isOAuth {
		if inner, err := unwrapGeminiResponse(raw); err == nil {
			if encoded, err := json.Marshal(inner); err == nil {
				return encoded
			}
		}
	}
	return raw
}
// collectGeminiSSE drains an upstream Gemini SSE stream and reduces it to a
// single response object.
//
// Every "data:" payload is parsed (unwrapped first when isOAuth is true);
// payloads that fail to parse are skipped. The result is the last parsed chunk
// that contained parts, falling back to the last parsed chunk, falling back to
// an empty map. The returned usage is taken from the last chunk that carried
// usage metadata (zero value if none did). Reading stops at a "[DONE]" payload
// or at EOF; any other read error is returned.
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
	var (
		latest          map[string]any
		latestWithParts map[string]any
	)
	usage := &ClaudeUsage{}
	br := bufio.NewReader(body)

	for {
		raw, readErr := br.ReadString('\n')

		if trimmed := strings.TrimRight(raw, "\r\n"); strings.HasPrefix(trimmed, "data:") {
			payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
			if payload == "[DONE]" {
				return pickGeminiCollectResult(latest, latestWithParts), usage, nil
			}
			if payload != "" {
				var event map[string]any
				if isOAuth {
					if inner, unwrapErr := unwrapGeminiResponse([]byte(payload)); unwrapErr == nil && inner != nil {
						event = inner
					}
				} else {
					_ = json.Unmarshal([]byte(payload), &event)
				}
				if event != nil {
					latest = event
					if u := extractGeminiUsage(event); u != nil {
						usage = u
					}
					if len(extractGeminiParts(event)) > 0 {
						latestWithParts = event
					}
				}
			}
		}

		// The line read alongside the error (if any) has been handled above.
		if readErr != nil {
			if errors.Is(readErr, io.EOF) {
				break
			}
			return nil, nil, readErr
		}
	}
	return pickGeminiCollectResult(latest, latestWithParts), usage, nil
}
// pickGeminiCollectResult chooses the object to return for a collected stream:
// lastWithParts if non-nil, otherwise last if non-nil, otherwise a new empty map.
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
	switch {
	case lastWithParts != nil:
		return lastWithParts
	case last != nil:
		return last
	}
	return map[string]any{}
}
// geminiNativeStreamResult is what handleNativeStreamingResponse reports after
// a native (pass-through) stream ends.
type geminiNativeStreamResult struct {
// usage holds the token counts from the last chunk that carried usage metadata.
usage *ClaudeUsage
// firstTokenMs is the elapsed time in milliseconds until the first non-empty
// data payload was seen; it is nil when the stream carried none.
firstTokenMs *int
}
// isGeminiInsufficientScope reports whether an upstream response indicates that
// the OAuth token lacks required scopes: either the Www-Authenticate header
// mentions "insufficient_scope", or the body contains one of the known
// insufficient-scope phrases. All comparisons are case-insensitive.
func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
	challenge := strings.ToLower(headers.Get("Www-Authenticate"))
	if strings.Contains(challenge, "insufficient_scope") {
		return true
	}

	text := strings.ToLower(string(body))
	for _, marker := range []string{
		"insufficient authentication scopes",
		"access_token_scope_insufficient",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// estimateGeminiCountTokens produces a rough token count for a Gemini request
// body by summing estimateTokensForText over every non-blank text part found in
// systemInstruction.parts[] and contents[].parts[]. It returns 0 when the body
// is not a JSON object or contains no text.
func estimateGeminiCountTokens(reqBody []byte) int {
	var req map[string]any
	if json.Unmarshal(reqBody, &req) != nil {
		return 0
	}

	total := 0
	// addParts adds the estimate of each non-blank "text" entry of a parts array.
	addParts := func(rawParts any) {
		parts, ok := rawParts.([]any)
		if !ok {
			return
		}
		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]any)
			if !ok {
				continue
			}
			if text, ok := part["text"].(string); ok && strings.TrimSpace(text) != "" {
				total += estimateTokensForText(text)
			}
		}
	}

	// systemInstruction.parts[].text
	if instruction, ok := req["systemInstruction"].(map[string]any); ok {
		addParts(instruction["parts"])
	}
	// contents[].parts[].text
	if contents, ok := req["contents"].([]any); ok {
		for _, rawContent := range contents {
			if content, ok := rawContent.(map[string]any); ok {
				addParts(content["parts"])
			}
		}
	}

	if total < 0 {
		return 0
	}
	return total
}
// estimateTokensForText heuristically estimates the token count of s after
// trimming surrounding whitespace. Text that is at least 80% ASCII is counted
// at roughly four runes per token (rounded up); anything else is counted at one
// token per rune. Blank input yields 0.
func estimateTokensForText(s string) int {
	text := strings.TrimSpace(s)

	runeCount, asciiCount := 0, 0
	for _, r := range text {
		runeCount++
		if r <= 0x7f {
			asciiCount++
		}
	}
	if runeCount == 0 {
		return 0
	}

	if float64(asciiCount)/float64(runeCount) >= 0.8 {
		// Roughly 4 chars per token for English-like text.
		return (runeCount + 3) / 4
	}
	// For CJK-heavy text, approximate 1 rune per token.
	return runeCount
}
// UpstreamHTTPResult carries the outcome of an upstream HTTP call that is
// handed back to the caller instead of being written to the client directly
// (see ForwardAIStudioGET).
type UpstreamHTTPResult struct {
// StatusCode is the upstream HTTP status code.
StatusCode int
// Headers holds the response headers selected for forwarding.
Headers http.Header
// Body is the upstream response body.
Body []byte
}
// handleNativeNonStreamingResponse relays a non-streaming native Gemini
// response to the client with the upstream status code and content type
// (defaulting to application/json), and returns the usage reported in the
// response's usage metadata (zero value when absent).
//
// For OAuth accounts the payload is unwrapped with unwrapGeminiResponse before
// being relayed; if unwrapping or re-encoding fails, the original bytes are
// sent unchanged. Upstream headers are copied through the configured filter
// only when a config is present, matching handleNativeStreamingResponse.
// An error is returned only when the upstream body cannot be read.
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
	// Log response headers for debugging
	log.Printf("[GeminiAPI] ========== Response Headers ==========")
	for key, values := range resp.Header {
		if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
			log.Printf("[GeminiAPI] %s: %v", key, values)
		}
	}
	log.Printf("[GeminiAPI] ========================================")

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var parsed map[string]any
	if isOAuth {
		parsed, err = unwrapGeminiResponse(respBody)
		if err == nil && parsed != nil {
			// Only swap in the unwrapped payload when it re-encodes cleanly, so
			// an encoding failure cannot replace the body with nothing.
			if unwrapped, marshalErr := json.Marshal(parsed); marshalErr == nil {
				respBody = unwrapped
			}
		}
	} else {
		// Best effort: the body is relayed even if it is not valid JSON.
		_ = json.Unmarshal(respBody, &parsed)
	}

	// Guard against a nil config instead of dereferencing it unconditionally.
	if s.cfg != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	}

	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	c.Data(resp.StatusCode, contentType, respBody)

	if parsed != nil {
		if u := extractGeminiUsage(parsed); u != nil {
			return u, nil
		}
	}
	return &ClaudeUsage{}, nil
}
// handleNativeStreamingResponse relays an upstream Gemini SSE stream to the
// client line by line, using the upstream status code and content type
// (defaulting to "text/event-stream; charset=utf-8").
//
// For OAuth accounts each data payload is unwrapped with unwrapGeminiResponse
// and re-emitted as "data: <json>\n\n"; for other accounts lines are forwarded
// verbatim. Usage is taken from the last payload that carried usage metadata,
// and the first-token latency is measured from startTime to the first
// non-empty data payload. An error is returned when the response writer cannot
// flush or when reading the upstream body fails with anything other than io.EOF.
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
// Header filtering is skipped entirely when no config is available.
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
// A final line without a trailing newline is still processed before the
// EOF check below.
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
// Keepalive / done markers
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// rawToWrite defaults to the original payload and is replaced by the
// unwrapped JSON only when unwrapping and re-encoding both succeed.
var rawToWrite string
rawToWrite = payload
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
if b, err := json.Marshal(inner); err == nil {
rawToWrite = string(b)
}
}
} else {
// Parsed only to read usage metadata; the line itself is forwarded as-is.
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if isOAuth {
// SSE format requires double newline (\n\n) to separate events
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite)
} else {
// Pass-through for AI Studio responses.
_, _ = io.WriteString(c.Writer, line)
}
flusher.Flush()
}
} else {
// Non-data lines (including blank separators) are forwarded unchanged.
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
//
// path must be non-empty and start with "/". The base URL comes from the
// account's base_url credential (falling back to geminicli.AIStudioBaseURL) and
// is validated before use. API-key accounts authenticate with the
// x-goog-api-key header, OAuth accounts with a bearer token from the token
// provider; other account types are rejected.
//
// The returned result contains the upstream status, the filtered response
// headers (Www-Authenticate is always preserved) and at most 8 MiB of body.
// An error is returned when the request cannot be built or sent, or when the
// response body cannot be read — a partially read body is no longer returned
// as if it were complete.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
	if account == nil {
		return nil, errors.New("account is nil")
	}
	path = strings.TrimSpace(path)
	if path == "" || !strings.HasPrefix(path, "/") {
		return nil, errors.New("invalid path")
	}

	baseURL := strings.TrimSpace(account.GetCredential("base_url"))
	if baseURL == "" {
		baseURL = geminicli.AIStudioBaseURL
	}
	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return nil, err
	}
	fullURL := strings.TrimRight(normalizedBaseURL, "/") + path

	var proxyURL string
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
	if err != nil {
		return nil, err
	}

	switch account.Type {
	case AccountTypeAPIKey:
		apiKey := strings.TrimSpace(account.GetCredential("api_key"))
		if apiKey == "" {
			return nil, errors.New("gemini api_key not configured")
		}
		req.Header.Set("x-goog-api-key", apiKey)
	case AccountTypeOAuth:
		if s.tokenProvider == nil {
			return nil, errors.New("gemini token provider not configured")
		}
		accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Authorization", "Bearer "+accessToken)
	default:
		return nil, fmt.Errorf("unsupported account type: %s", account.Type)
	}

	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	// Surface read failures instead of handing back a truncated body as success.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
	if err != nil {
		return nil, fmt.Errorf("read upstream response: %w", err)
	}

	// Www-Authenticate is kept even if the header filter would drop it.
	wwwAuthenticate := resp.Header.Get("Www-Authenticate")
	filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
	if wwwAuthenticate != "" {
		filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
	}

	return &UpstreamHTTPResult{
		StatusCode: resp.StatusCode,
		Headers:    filteredHeaders,
		Body:       body,
	}, nil
}
// unwrapGeminiResponse decodes raw as a JSON object. When the object has a
// "response" field that is itself an object, that inner object is returned;
// otherwise the decoded object is returned as-is. A decode failure is returned
// as the error.
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
	var envelope map[string]any
	err := json.Unmarshal(raw, &envelope)
	if err != nil {
		return nil, err
	}

	inner, wrapped := envelope["response"].(map[string]any)
	if !wrapped || inner == nil {
		return envelope, nil
	}
	return inner, nil
}
// convertGeminiToClaudeMessage turns a (non-streaming) Gemini response into a
// Claude Messages response object and returns it together with the extracted
// usage (zero value when the response carries no usage metadata).
//
// Only the first candidate is converted. Non-empty text parts become "text"
// blocks and functionCall parts become "tool_use" blocks with a freshly
// generated id; a blank function name is replaced by "tool". When a function
// call has no args, "input" is emitted as an empty object rather than null,
// matching the streaming path, which starts tool_use blocks with an empty
// input object. The stop reason is derived from the finish reason and forced
// to "tool_use" when any tool_use block was produced.
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
	usage := extractGeminiUsage(geminiResp)
	if usage == nil {
		usage = &ClaudeUsage{}
	}

	contentBlocks := make([]any, 0)
	sawToolUse := false
	for _, part := range extractGeminiParts(geminiResp) {
		if text, ok := part["text"].(string); ok && text != "" {
			contentBlocks = append(contentBlocks, map[string]any{
				"type": "text",
				"text": text,
			})
		}
		// Not an else-branch: a part carrying both text and a function call
		// yields both blocks.
		if fc, ok := part["functionCall"].(map[string]any); ok {
			name, _ := fc["name"].(string)
			if strings.TrimSpace(name) == "" {
				name = "tool"
			}
			args := fc["args"]
			if args == nil {
				// A missing args field must not surface as "input": null.
				args = map[string]any{}
			}
			sawToolUse = true
			contentBlocks = append(contentBlocks, map[string]any{
				"type":  "tool_use",
				"id":    "toolu_" + randomHex(8),
				"name":  name,
				"input": args,
			})
		}
	}

	stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp))
	if sawToolUse {
		stopReason = "tool_use"
	}

	resp := map[string]any{
		"id":            "msg_" + randomHex(12),
		"type":          "message",
		"role":          "assistant",
		"model":         originalModel,
		"content":       contentBlocks,
		"stop_reason":   stopReason,
		"stop_sequence": nil,
		"usage": map[string]any{
			"input_tokens":  usage.InputTokens,
			"output_tokens": usage.OutputTokens,
		},
	}
	return resp, usage
}
// extractGeminiUsage reads usageMetadata from a Gemini response and maps
// promptTokenCount to InputTokens and candidatesTokenCount to OutputTokens.
// It returns nil when usageMetadata is missing or not an object; counts that
// are missing or non-numeric are reported as 0.
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
	meta, _ := geminiResp["usageMetadata"].(map[string]any)
	if meta == nil {
		return nil
	}

	usage := new(ClaudeUsage)
	usage.InputTokens, _ = asInt(meta["promptTokenCount"])
	usage.OutputTokens, _ = asInt(meta["candidatesTokenCount"])
	return usage
}
// asInt converts a decoded JSON number to int. It accepts float64 (truncated
// toward zero), int, int64 and json.Number values that parse as an integer.
// The boolean is false, and the int 0, for any other input.
func asInt(v any) (int, bool) {
	switch n := v.(type) {
	case int:
		return n, true
	case int64:
		return int(n), true
	case float64:
		return int(n), true
	case json.Number:
		if parsed, err := n.Int64(); err == nil {
			return int(parsed), true
		}
	}
	return 0, false
}
// handleGeminiUpstreamError records the consequences of an upstream error on
// the account.
//
// 401, 403 and 529 are delegated to the rate limit service when one is
// configured. 429 marks the account as rate limited: until the reset time
// parsed from the response body when available, otherwise for a fallback
// period that depends on the account kind. All other statuses are ignored.
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
	switch statusCode {
	case 401, 403, 529:
		if s.rateLimitService != nil {
			s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
		}
		return
	case 429:
		// Handled below.
	default:
		return
	}

	oauthType := account.GeminiOAuthType()
	tierID := account.GeminiTierID()
	projectID := strings.TrimSpace(account.GetCredential("project_id"))
	codeAssist := account.IsGeminiCodeAssist()

	// Prefer the reset time parsed from the upstream response.
	if parsedReset := ParseGeminiRateLimitResetTime(body); parsedReset != nil {
		until := time.Unix(*parsedReset, 0)
		_ = s.accountRepo.SetRateLimited(ctx, account.ID, until)
		log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
			account.ID, until, oauthType, tierID)
		return
	}

	// No reset time in the response: pick a default based on the account kind.
	var until time.Time
	switch {
	case codeAssist:
		// Code Assist: fallback cooldown by tier
		cooldown := geminiCooldownForTier(tierID)
		if s.rateLimitService != nil {
			cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
		}
		until = time.Now().Add(cooldown)
		log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(until).Truncate(time.Second))
	default:
		// API Key / AI Studio OAuth: next daily reset (PST midnight).
		if ts := nextGeminiDailyResetUnix(); ts != nil {
			until = time.Unix(*ts, 0)
			log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, until)
		} else {
			// Last resort: 5 minutes.
			until = time.Now().Add(5 * time.Minute)
			log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
		}
	}
	_ = s.accountRepo.SetRateLimited(ctx, account.ID, until)
}
// ParseGeminiRateLimitResetTime parses a Gemini-format 429 response body and
// returns the reset time as a Unix timestamp, or nil when none can be derived.
//
// Sources are tried in this order:
//  1. error.message looks like a daily-quota message: the next daily reset.
//  2. error.details[].metadata.quotaResetDelay (a duration such as "12.345s").
//  3. A "Please retry in Xs" phrase anywhere in the body.
//
// Delays are rounded up to whole seconds in both duration branches so the
// computed reset time is never earlier than the one the upstream announced.
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
	var parsed map[string]any
	if err := json.Unmarshal(body, &parsed); err == nil {
		if errObj, ok := parsed["error"].(map[string]any); ok {
			if msg, ok := errObj["message"].(string); ok && looksLikeGeminiDailyQuota(msg) {
				if ts := nextGeminiDailyResetUnix(); ts != nil {
					return ts
				}
			}

			// Try to parse metadata.quotaResetDelay like "12.345s"
			details, _ := errObj["details"].([]any)
			for _, d := range details {
				dm, ok := d.(map[string]any)
				if !ok {
					continue
				}
				meta, ok := dm["metadata"].(map[string]any)
				if !ok {
					continue
				}
				v, ok := meta["quotaResetDelay"].(string)
				if !ok {
					continue
				}
				dur, err := time.ParseDuration(v)
				if err != nil {
					continue
				}
				ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
				return &ts
			}
		}
	}

	// Match "Please retry in Xs"
	matches := retryInRegex.FindStringSubmatch(string(body))
	if len(matches) == 2 {
		if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
			ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
			return &ts
		}
	}
	return nil
}
// looksLikeGeminiDailyQuota reports whether message refers to a per-day quota,
// i.e. whether it contains "per day" (case-insensitive). This single check
// covers the "requests per day" and "quota ... per day" phrasings as well.
func looksLikeGeminiDailyQuota(message string) bool {
	return strings.Contains(strings.ToLower(message), "per day")
}
// nextGeminiDailyResetUnix returns the Unix timestamp of the daily reset time
// that geminiDailyResetTime computes for the current moment. The returned
// pointer is never nil.
func nextGeminiDailyResetUnix() *int64 {
	unix := geminiDailyResetTime(time.Now()).Unix()
	return &unix
}
// extractGeminiFinishReason returns candidates[0].finishReason from a Gemini
// response, or "" when it is absent or not a string.
func extractGeminiFinishReason(geminiResp map[string]any) string {
	candidates, _ := geminiResp["candidates"].([]any)
	if len(candidates) == 0 {
		return ""
	}
	// A failed assertion leaves a nil map, whose lookups simply yield nothing.
	first, _ := candidates[0].(map[string]any)
	reason, _ := first["finishReason"].(string)
	return reason
}
// extractGeminiParts returns the object-valued entries of
// candidates[0].content.parts from a Gemini response, in order. It returns nil
// when the path is missing or the parts array is empty; entries that are not
// objects are skipped.
func extractGeminiParts(geminiResp map[string]any) []map[string]any {
	candidates, _ := geminiResp["candidates"].([]any)
	if len(candidates) == 0 {
		return nil
	}
	// Failed assertions leave nil maps, whose lookups simply yield nothing.
	first, _ := candidates[0].(map[string]any)
	content, _ := first["content"].(map[string]any)
	rawParts, _ := content["parts"].([]any)
	if len(rawParts) == 0 {
		return nil
	}

	parts := make([]map[string]any, 0, len(rawParts))
	for _, raw := range rawParts {
		if part, ok := raw.(map[string]any); ok {
			parts = append(parts, part)
		}
	}
	return parts
}
// computeGeminiTextDelta works out which part of incoming is new relative to
// the text already seen, and returns that delta with the updated seen text.
// A single trailing NUL character on incoming is dropped first.
//
//   - empty incoming: nothing new, seen unchanged
//   - incoming starts with seen (cumulative mode): the delta is the remainder
//   - seen starts with incoming (duplicate/rewind): nothing new
//   - otherwise (delta mode): incoming is an increment appended to seen
func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) {
	chunk := strings.TrimSuffix(incoming, "\u0000")
	switch {
	case chunk == "":
		return "", seen
	case strings.HasPrefix(chunk, seen):
		return chunk[len(seen):], chunk
	case strings.HasPrefix(seen, chunk):
		return "", seen
	default:
		return chunk, seen + chunk
	}
}
// mapGeminiFinishReasonToClaudeStopReason maps a Gemini finish reason to a
// Claude stop reason: "MAX_TOKENS" (ignoring case and surrounding whitespace)
// becomes "max_tokens"; every other value, including "STOP", becomes "end_turn".
func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string {
	normalized := strings.ToUpper(strings.TrimSpace(finishReason))
	if normalized == "MAX_TOKENS" {
		return "max_tokens"
	}
	return "end_turn"
}
// convertClaudeMessagesToGeminiGenerateContent translates a Claude Messages
// request body into a Gemini generateContent request body.
//
// The system prompt becomes systemInstruction, messages become contents, tools
// become functionDeclarations and sampling options become generationConfig;
// each optional section is emitted only when it has content. Function-call
// ids are stripped from the result before encoding. An error is returned when
// body is not valid JSON or when "messages" is not an array.
func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) {
	var claudeReq map[string]any
	if err := json.Unmarshal(body, &claudeReq); err != nil {
		return nil, err
	}

	// Filled while converting messages so tool_result blocks can be matched to
	// the name of the tool_use block they answer.
	nameByToolUseID := map[string]string{}
	system := extractClaudeSystemText(claudeReq["system"])
	contents, err := convertClaudeMessagesToGeminiContents(claudeReq["messages"], nameByToolUseID)
	if err != nil {
		return nil, err
	}

	geminiReq := map[string]any{"contents": contents}
	if system != "" {
		geminiReq["systemInstruction"] = map[string]any{
			"parts": []any{map[string]any{"text": system}},
		}
	}
	if tools := convertClaudeToolsToGeminiTools(claudeReq["tools"]); tools != nil {
		geminiReq["tools"] = tools
	}
	if genCfg := convertClaudeGenerationConfig(claudeReq); genCfg != nil {
		geminiReq["generationConfig"] = genCfg
	}

	stripGeminiFunctionIDs(geminiReq)
	return json.Marshal(geminiReq)
}
// stripGeminiFunctionIDs removes the "id" field from every functionCall and
// functionResponse object found in req["contents"][].parts[], modifying req in
// place. Entries that do not have the expected shape are left untouched.
//
// Defensive cleanup: some upstreams reject unexpected `id` fields in
// functionCall/functionResponse.
func stripGeminiFunctionIDs(req map[string]any) {
	// Failed assertions yield nil slices/maps, which the loops and lookups
	// below treat as empty.
	contents, _ := req["contents"].([]any)
	for _, rawContent := range contents {
		content, _ := rawContent.(map[string]any)
		parts, _ := content["parts"].([]any)
		for _, rawPart := range parts {
			part, _ := rawPart.(map[string]any)
			for _, key := range [...]string{"functionCall", "functionResponse"} {
				if fn, ok := part[key].(map[string]any); ok {
					delete(fn, "id")
				}
			}
		}
	}
}
// extractClaudeSystemText flattens a Claude "system" field into plain text.
// A string is returned trimmed. An array contributes the non-blank text of its
// "text"-typed blocks, joined with newlines and trimmed. Any other value
// yields "".
func extractClaudeSystemText(system any) string {
	if s, ok := system.(string); ok {
		return strings.TrimSpace(s)
	}
	blocks, ok := system.([]any)
	if !ok {
		return ""
	}

	var texts []string
	for _, raw := range blocks {
		block, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		kind, _ := block["type"].(string)
		text, isString := block["text"].(string)
		if kind == "text" && isString && strings.TrimSpace(text) != "" {
			texts = append(texts, text)
		}
	}
	return strings.TrimSpace(strings.Join(texts, "\n"))
}
// convertClaudeMessagesToGeminiContents converts a Claude "messages" array into
// Gemini "contents" entries ({"role", "parts"}).
//
// The "assistant" role maps to "model"; every other role maps to "user".
// String content becomes a single text part. Block content is converted per
// block type: text, tool_use (functionCall), tool_result (functionResponse),
// base64 image (inlineData); unknown block types are preserved as their JSON
// text. toolUseIDToName is filled from tool_use blocks and consulted for
// tool_result blocks, which fall back to the name "tool" when the id is
// unknown. Entries of messages that are not objects are skipped. An error is
// returned only when messages is not an array.
func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) {
	list, isList := messages.([]any)
	if !isList {
		return nil, errors.New("messages must be an array")
	}

	contents := make([]any, 0, len(list))
	for _, rawMsg := range list {
		msg, ok := rawMsg.(map[string]any)
		if !ok {
			continue
		}

		geminiRole := "user"
		if r, _ := msg["role"].(string); strings.ToLower(strings.TrimSpace(r)) == "assistant" {
			geminiRole = "model"
		}

		// Kept non-nil so a message without usable content still yields an
		// empty parts array.
		parts := []any{}
		switch body := msg["content"].(type) {
		case string:
			// String content is kept verbatim, including whitespace-only text.
			parts = append(parts, map[string]any{"text": body})
		case []any:
			// With a single block nothing is filtered (the upstream API gets to
			// report blank content); with several blocks blank text is dropped.
			keepBlank := len(body) == 1
			for _, rawBlock := range body {
				block, ok := rawBlock.(map[string]any)
				if !ok {
					continue
				}
				kind, _ := block["type"].(string)
				switch kind {
				case "text":
					text, isString := block["text"].(string)
					if isString && (keepBlank || strings.TrimSpace(text) != "") {
						parts = append(parts, map[string]any{"text": text})
					}
				case "tool_use":
					id, _ := block["id"].(string)
					name, _ := block["name"].(string)
					if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
						toolUseIDToName[id] = name
					}
					call := map[string]any{
						"name": name,
						"args": block["input"],
					}
					parts = append(parts, map[string]any{"functionCall": call})
				case "tool_result":
					toolUseID, _ := block["tool_use_id"].(string)
					name := toolUseIDToName[toolUseID]
					if name == "" {
						name = "tool"
					}
					result := map[string]any{
						"name": name,
						"response": map[string]any{
							"content": extractClaudeContentText(block["content"]),
						},
					}
					parts = append(parts, map[string]any{"functionResponse": result})
				case "image":
					// Only base64 sources with both a media type and data are forwarded.
					source, _ := block["source"].(map[string]any)
					if encoding, _ := source["type"].(string); encoding == "base64" {
						mediaType, _ := source["media_type"].(string)
						data, _ := source["data"].(string)
						if mediaType != "" && data != "" {
							inline := map[string]any{
								"mimeType": mediaType,
								"data":     data,
							}
							parts = append(parts, map[string]any{"inlineData": inline})
						}
					}
				default:
					// best-effort: preserve unknown blocks as text
					if encoded, err := json.Marshal(block); err == nil {
						parts = append(parts, map[string]any{"text": string(encoded)})
					}
				}
			}
		}

		contents = append(contents, map[string]any{
			"role":  geminiRole,
			"parts": parts,
		})
	}
	return contents, nil
}
// extractClaudeContentText flattens Claude content into a string. A string is
// returned as-is; an array yields the concatenation (no separator) of the text
// of its "text"-typed blocks; any other value is returned as its JSON encoding
// (so nil becomes "null").
func extractClaudeContentText(v any) string {
	switch content := v.(type) {
	case string:
		return content
	case []any:
		texts := make([]string, 0, len(content))
		for _, raw := range content {
			block, ok := raw.(map[string]any)
			if !ok {
				continue
			}
			if kind, _ := block["type"].(string); kind != "text" {
				continue
			}
			if text, ok := block["text"].(string); ok {
				texts = append(texts, text)
			}
		}
		return strings.Join(texts, "")
	}
	encoded, _ := json.Marshal(v)
	return string(encoded)
}
// convertClaudeToolsToGeminiTools converts a Claude "tools" array into a Gemini
// tools array holding a single functionDeclarations entry.
//
// Standard tools take description and input_schema from the tool object
// itself; tools of type "custom" take them from the nested "custom" object
// (and are skipped when that object is missing). The name always comes from
// the tool object, and tools without a name are skipped. A missing schema
// defaults to an empty object schema, and every schema is passed through
// cleanToolSchema. It returns nil when tools is not a non-empty array or no
// tool survives conversion.
func convertClaudeToolsToGeminiTools(tools any) []any {
	list, _ := tools.([]any)
	if len(list) == 0 {
		return nil
	}

	decls := make([]any, 0, len(list))
	for _, raw := range list {
		tool, ok := raw.(map[string]any)
		if !ok {
			continue
		}

		// spec is the object carrying description and input_schema.
		spec := tool
		if kind, _ := tool["type"].(string); kind == "custom" {
			// Custom (MCP) tools keep those fields under "custom".
			custom, ok := tool["custom"].(map[string]any)
			if !ok {
				continue
			}
			spec = custom
		}

		name, _ := tool["name"].(string)
		if name == "" {
			continue
		}
		description, _ := spec["description"].(string)

		schema := spec["input_schema"]
		if schema == nil {
			// Default for tools that declare no parameters.
			schema = map[string]any{
				"type":       "object",
				"properties": map[string]any{},
			}
		}

		decls = append(decls, map[string]any{
			"name":        name,
			"description": description,
			"parameters":  cleanToolSchema(schema),
		})
	}

	if len(decls) == 0 {
		return nil
	}
	return []any{
		map[string]any{"functionDeclarations": decls},
	}
}
// cleanToolSchema returns a cleaned copy of a tool's JSON Schema with the
// keywords Gemini does not support removed. The input is not modified.
//
// For a map (a schema object):
//   - the keywords "$schema", "$id", "$ref", "additionalProperties",
//     "minLength", "maxLength", "minItems" and "maxItems" are dropped;
//   - the value of "properties" is treated as a map from parameter name to
//     schema: every name is kept and only the per-parameter schemas are
//     cleaned, so a parameter that happens to be called e.g. "maxLength" is
//     not mistaken for a keyword and removed;
//   - every other value is cleaned recursively;
//   - a string "type" is normalized to upper case.
//
// For a slice, each element is cleaned recursively. Any other value
// (including nil) is returned unchanged.
func cleanToolSchema(schema any) any {
	if schema == nil {
		return nil
	}
	switch v := schema.(type) {
	case map[string]any:
		cleaned := make(map[string]any, len(v))
		for key, value := range v {
			switch key {
			case "$schema", "$id", "$ref", "additionalProperties",
				"minLength", "maxLength", "minItems", "maxItems":
				// Unsupported keyword: skip it.
				continue
			case "properties":
				// Keys here are parameter names, not schema keywords, so they
				// must not go through the keyword filter above.
				if props, ok := value.(map[string]any); ok {
					named := make(map[string]any, len(props))
					for propName, propSchema := range props {
						named[propName] = cleanToolSchema(propSchema)
					}
					cleaned[key] = named
					continue
				}
			}
			// Recursively clean nested schemas.
			cleaned[key] = cleanToolSchema(value)
		}
		// Normalize the type field to upper case.
		if typeVal, ok := cleaned["type"].(string); ok {
			cleaned["type"] = strings.ToUpper(typeVal)
		}
		return cleaned
	case []any:
		cleaned := make([]any, len(v))
		for i, item := range v {
			cleaned[i] = cleanToolSchema(item)
		}
		return cleaned
	default:
		return v
	}
}
// convertClaudeGenerationConfig builds a Gemini generationConfig map from the
// sampling fields of a Claude request body.
//
// Field mapping (a field is copied only when present with the expected type):
//   - "max_tokens"     -> "maxOutputTokens" (via asInt, only when positive)
//   - "temperature"    -> "temperature"     (float64)
//   - "top_p"          -> "topP"            (float64)
//   - "stop_sequences" -> "stopSequences"   (non-empty []any)
//
// It returns nil when none of the fields apply.
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
	cfg := map[string]any{}

	if maxTokens, ok := asInt(req["max_tokens"]); ok && maxTokens > 0 {
		cfg["maxOutputTokens"] = maxTokens
	}

	// Float-valued sampling parameters only differ in their key names.
	floatFields := [...]struct{ claudeKey, geminiKey string }{
		{"temperature", "temperature"},
		{"top_p", "topP"},
	}
	for _, f := range floatFields {
		if val, ok := req[f.claudeKey].(float64); ok {
			cfg[f.geminiKey] = val
		}
	}

	// A failed assertion leaves seqs nil, so the length check covers both cases.
	if seqs, _ := req["stop_sequences"].([]any); len(seqs) > 0 {
		cfg["stopSequences"] = seqs
	}

	if len(cfg) == 0 {
		return nil
	}
	return cfg
}