2550 lines
71 KiB
Go
2550 lines
71 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"math"
|
||
mathrand "math/rand"
|
||
"net/http"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Tunables for Gemini sticky sessions and upstream retries.
const (
	// geminiStickySessionTTL is the TTL applied when a session hash is bound
	// to (or refreshed for) an account in the gateway cache.
	geminiStickySessionTTL = time.Hour

	// geminiMaxRetries caps the number of upstream attempts per forwarded request.
	geminiMaxRetries = 5

	// geminiRetryBaseDelay and geminiRetryMaxDelay are presumably the lower and
	// upper bounds of the retry backoff — TODO confirm against sleepGeminiBackoff,
	// which is defined outside this section.
	geminiRetryBaseDelay = time.Second
	geminiRetryMaxDelay  = 16 * time.Second
)
|
||
|
||
// GeminiMessagesCompatService selects Gemini (and, in mixed-scheduling mode,
// Antigravity) accounts and forwards requests to the Gemini upstream APIs.
type GeminiMessagesCompatService struct {
	// accountRepo looks up accounts by ID and lists schedulable accounts.
	accountRepo AccountRepository
	// groupRepo resolves a group ID to its group (used for the group's platform).
	groupRepo GroupRepository
	// cache stores the sticky session -> account ID binding.
	cache GatewayCache
	// tokenProvider supplies access tokens for OAuth accounts.
	tokenProvider *GeminiTokenProvider
	// rateLimitService is optional (nil-checked before use); it pre-checks usage
	// and handles temporarily unschedulable accounts.
	rateLimitService *RateLimitService
	// httpUpstream executes the outbound HTTP requests.
	httpUpstream HTTPUpstream
	// antigravityGatewayService is exposed via GetAntigravityGatewayService.
	antigravityGatewayService *AntigravityGatewayService
	// cfg provides the URL allowlist settings used to validate upstream base URLs.
	cfg *config.Config
}
|
||
|
||
func NewGeminiMessagesCompatService(
|
||
accountRepo AccountRepository,
|
||
groupRepo GroupRepository,
|
||
cache GatewayCache,
|
||
tokenProvider *GeminiTokenProvider,
|
||
rateLimitService *RateLimitService,
|
||
httpUpstream HTTPUpstream,
|
||
antigravityGatewayService *AntigravityGatewayService,
|
||
cfg *config.Config,
|
||
) *GeminiMessagesCompatService {
|
||
return &GeminiMessagesCompatService{
|
||
accountRepo: accountRepo,
|
||
groupRepo: groupRepo,
|
||
cache: cache,
|
||
tokenProvider: tokenProvider,
|
||
rateLimitService: rateLimitService,
|
||
httpUpstream: httpUpstream,
|
||
antigravityGatewayService: antigravityGatewayService,
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
// GetTokenProvider returns the token provider for OAuth accounts
|
||
func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
||
return s.tokenProvider
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||
var platform string
|
||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||
if hasForcePlatform && forcePlatform != "" {
|
||
platform = forcePlatform
|
||
} else if groupID != nil {
|
||
// 根据分组 platform 决定查询哪种账号
|
||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get group failed: %w", err)
|
||
}
|
||
platform = group.Platform
|
||
} else {
|
||
// 无分组时只使用原生 gemini 平台
|
||
platform = PlatformGemini
|
||
}
|
||
|
||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||
// 注意:强制平台模式不走混合调度
|
||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||
var queryPlatforms []string
|
||
if useMixedScheduling {
|
||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
||
} else {
|
||
queryPlatforms = []string{platform}
|
||
}
|
||
|
||
cacheKey := "gemini:" + sessionHash
|
||
|
||
if sessionHash != "" {
|
||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||
if err == nil && accountID > 0 {
|
||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||
valid := false
|
||
if account.Platform == platform {
|
||
valid = true
|
||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||
valid = true
|
||
}
|
||
if valid {
|
||
usable := true
|
||
if s.rateLimitService != nil && requestedModel != "" {
|
||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||
if err != nil {
|
||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||
}
|
||
if !ok {
|
||
usable = false
|
||
}
|
||
}
|
||
if usable {
|
||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||
return account, nil
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||
var accounts []Account
|
||
var err error
|
||
if groupID != nil {
|
||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||
}
|
||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||
if len(accounts) == 0 && hasForcePlatform {
|
||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||
}
|
||
} else {
|
||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||
}
|
||
|
||
var selected *Account
|
||
for i := range accounts {
|
||
acc := &accounts[i]
|
||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||
continue
|
||
}
|
||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
||
// 非混合调度模式(antigravity 分组):不需要过滤
|
||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||
continue
|
||
}
|
||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||
continue
|
||
}
|
||
if s.rateLimitService != nil && requestedModel != "" {
|
||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||
if err != nil {
|
||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||
}
|
||
if !ok {
|
||
continue
|
||
}
|
||
}
|
||
if selected == nil {
|
||
selected = acc
|
||
continue
|
||
}
|
||
if acc.Priority < selected.Priority {
|
||
selected = acc
|
||
} else if acc.Priority == selected.Priority {
|
||
switch {
|
||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||
selected = acc
|
||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||
// keep selected (never used is preferred)
|
||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||
selected = acc
|
||
}
|
||
default:
|
||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||
selected = acc
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if selected == nil {
|
||
if requestedModel != "" {
|
||
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
|
||
}
|
||
return nil, errors.New("no available Gemini accounts")
|
||
}
|
||
|
||
if sessionHash != "" {
|
||
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
|
||
}
|
||
|
||
return selected, nil
|
||
}
|
||
|
||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||
if account.Platform == PlatformAntigravity {
|
||
return IsAntigravityModelSupported(requestedModel)
|
||
}
|
||
return account.IsModelSupported(requestedModel)
|
||
}
|
||
|
||
// GetAntigravityGatewayService 返回 AntigravityGatewayService
|
||
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
|
||
return s.antigravityGatewayService
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||
if err != nil {
|
||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||
}
|
||
return normalized, nil
|
||
}
|
||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||
RequireAllowlist: true,
|
||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||
})
|
||
if err != nil {
|
||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||
}
|
||
return normalized, nil
|
||
}
|
||
|
||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||
var accounts []Account
|
||
var err error
|
||
if groupID != nil {
|
||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
||
} else {
|
||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return len(accounts) > 0, nil
|
||
}
|
||
|
||
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
|
||
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
|
||
//
|
||
// Preference order:
|
||
// 1) API key accounts (AI Studio)
|
||
// 2) OAuth accounts without project_id (AI Studio OAuth)
|
||
// 3) OAuth accounts explicitly marked as ai_studio
|
||
// 4) Any remaining Gemini accounts (fallback)
|
||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||
var accounts []Account
|
||
var err error
|
||
if groupID != nil {
|
||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||
} else {
|
||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||
}
|
||
if len(accounts) == 0 {
|
||
return nil, errors.New("no available Gemini accounts")
|
||
}
|
||
|
||
rank := func(a *Account) int {
|
||
if a == nil {
|
||
return 999
|
||
}
|
||
switch a.Type {
|
||
case AccountTypeAPIKey:
|
||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||
return 0
|
||
}
|
||
return 9
|
||
case AccountTypeOAuth:
|
||
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
|
||
return 1
|
||
}
|
||
if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" {
|
||
return 2
|
||
}
|
||
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
|
||
return 3
|
||
default:
|
||
return 10
|
||
}
|
||
}
|
||
|
||
var selected *Account
|
||
for i := range accounts {
|
||
acc := &accounts[i]
|
||
if selected == nil {
|
||
selected = acc
|
||
continue
|
||
}
|
||
|
||
r1, r2 := rank(acc), rank(selected)
|
||
if r1 < r2 {
|
||
selected = acc
|
||
continue
|
||
}
|
||
if r1 > r2 {
|
||
continue
|
||
}
|
||
|
||
if acc.Priority < selected.Priority {
|
||
selected = acc
|
||
} else if acc.Priority == selected.Priority {
|
||
switch {
|
||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||
selected = acc
|
||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||
// keep selected
|
||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||
selected = acc
|
||
}
|
||
default:
|
||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||
selected = acc
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if selected == nil {
|
||
return nil, errors.New("no available Gemini accounts")
|
||
}
|
||
return selected, nil
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
var req struct {
|
||
Model string `json:"model"`
|
||
Stream bool `json:"stream"`
|
||
}
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
return nil, fmt.Errorf("parse request: %w", err)
|
||
}
|
||
if strings.TrimSpace(req.Model) == "" {
|
||
return nil, fmt.Errorf("missing model")
|
||
}
|
||
|
||
originalModel := req.Model
|
||
mappedModel := req.Model
|
||
if account.Type == AccountTypeAPIKey {
|
||
mappedModel = account.GetMappedModel(req.Model)
|
||
}
|
||
|
||
geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||
}
|
||
originalClaudeBody := body
|
||
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
var requestIDHeader string
|
||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||
useUpstreamStream := req.Stream
|
||
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||
useUpstreamStream = true
|
||
}
|
||
|
||
switch account.Type {
|
||
case AccountTypeAPIKey:
|
||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||
apiKey := account.GetCredential("api_key")
|
||
if strings.TrimSpace(apiKey) == "" {
|
||
return nil, "", errors.New("gemini api_key not configured")
|
||
}
|
||
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
if baseURL == "" {
|
||
baseURL = geminicli.AIStudioBaseURL
|
||
}
|
||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
action := "generateContent"
|
||
if req.Stream {
|
||
action = "streamGenerateContent"
|
||
}
|
||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||
if req.Stream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("x-goog-api-key", apiKey)
|
||
return upstreamReq, "x-request-id", nil
|
||
}
|
||
requestIDHeader = "x-request-id"
|
||
|
||
case AccountTypeOAuth:
|
||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||
if s.tokenProvider == nil {
|
||
return nil, "", errors.New("gemini token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
action := "generateContent"
|
||
if useUpstreamStream {
|
||
action = "streamGenerateContent"
|
||
}
|
||
|
||
// Two modes for OAuth:
|
||
// 1. With project_id -> Code Assist API (wrapped request)
|
||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||
if projectID != "" {
|
||
// Mode 1: Code Assist API
|
||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
|
||
if useUpstreamStream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
wrapped := map[string]any{
|
||
"model": mappedModel,
|
||
"project": projectID,
|
||
}
|
||
var inner any
|
||
if err := json.Unmarshal(geminiReq, &inner); err != nil {
|
||
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
|
||
}
|
||
wrapped["request"] = inner
|
||
wrappedBytes, _ := json.Marshal(wrapped)
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||
return upstreamReq, "x-request-id", nil
|
||
} else {
|
||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
if baseURL == "" {
|
||
baseURL = geminicli.AIStudioBaseURL
|
||
}
|
||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||
if useUpstreamStream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||
return upstreamReq, "x-request-id", nil
|
||
}
|
||
}
|
||
requestIDHeader = "x-request-id"
|
||
|
||
default:
|
||
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
|
||
}
|
||
|
||
var resp *http.Response
|
||
signatureRetryStage := 0
|
||
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
|
||
upstreamReq, idHeader, err := buildReq(ctx)
|
||
if err != nil {
|
||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||
return nil, err
|
||
}
|
||
// Local build error: don't retry.
|
||
if strings.Contains(err.Error(), "missing project_id") {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||
}
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error())
|
||
}
|
||
requestIDHeader = idHeader
|
||
|
||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
if attempt < geminiMaxRetries {
|
||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||
sleepGeminiBackoff(attempt)
|
||
continue
|
||
}
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
||
}
|
||
|
||
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
|
||
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
|
||
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
|
||
if isGeminiSignatureRelatedError(respBody) {
|
||
var strippedClaudeBody []byte
|
||
stageName := ""
|
||
switch signatureRetryStage {
|
||
case 0:
|
||
// Stage 1: disable thinking + thinking->text
|
||
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
|
||
stageName = "thinking-only"
|
||
signatureRetryStage = 1
|
||
default:
|
||
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
|
||
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
|
||
stageName = "thinking+tools"
|
||
signatureRetryStage = 2
|
||
}
|
||
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
|
||
if txErr == nil {
|
||
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
|
||
geminiReq = retryGeminiReq
|
||
// Consume one retry budget attempt and continue with the updated request payload.
|
||
sleepGeminiBackoff(1)
|
||
continue
|
||
}
|
||
}
|
||
|
||
// Restore body for downstream error handling.
|
||
resp = &http.Response{
|
||
StatusCode: http.StatusBadRequest,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break
|
||
}
|
||
|
||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
// Don't treat insufficient-scope as transient.
|
||
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break
|
||
}
|
||
if resp.StatusCode == 429 {
|
||
// Mark as rate-limited early so concurrent requests avoid this account.
|
||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||
}
|
||
if attempt < geminiMaxRetries {
|
||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||
sleepGeminiBackoff(attempt)
|
||
continue
|
||
}
|
||
// Final attempt: surface the upstream error body (mapped below) instead of a generic retry error.
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break
|
||
}
|
||
|
||
break
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
tempMatched := false
|
||
if s.rateLimitService != nil {
|
||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||
}
|
||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||
if tempMatched {
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||
}
|
||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||
}
|
||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
||
}
|
||
|
||
requestID := resp.Header.Get(requestIDHeader)
|
||
if requestID == "" {
|
||
requestID = resp.Header.Get("x-goog-request-id")
|
||
}
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
if req.Stream {
|
||
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
} else {
|
||
if useUpstreamStream {
|
||
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
|
||
}
|
||
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
|
||
c.JSON(http.StatusOK, claudeResp)
|
||
usage = usageObj2
|
||
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
|
||
usage = usageObj
|
||
}
|
||
} else {
|
||
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel,
|
||
Stream: req.Stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
}, nil
|
||
}
|
||
|
||
func isGeminiSignatureRelatedError(respBody []byte) bool {
|
||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||
if msg == "" {
|
||
msg = strings.ToLower(string(respBody))
|
||
}
|
||
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
if strings.TrimSpace(originalModel) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||
}
|
||
if strings.TrimSpace(action) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||
}
|
||
if len(body) == 0 {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||
}
|
||
|
||
switch action {
|
||
case "generateContent", "streamGenerateContent", "countTokens":
|
||
// ok
|
||
default:
|
||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||
}
|
||
|
||
mappedModel := originalModel
|
||
if account.Type == AccountTypeAPIKey {
|
||
mappedModel = account.GetMappedModel(originalModel)
|
||
}
|
||
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
useUpstreamStream := stream
|
||
upstreamAction := action
|
||
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||
useUpstreamStream = true
|
||
upstreamAction = "streamGenerateContent"
|
||
}
|
||
forceAIStudio := action == "countTokens"
|
||
|
||
var requestIDHeader string
|
||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||
|
||
switch account.Type {
|
||
case AccountTypeAPIKey:
|
||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||
apiKey := account.GetCredential("api_key")
|
||
if strings.TrimSpace(apiKey) == "" {
|
||
return nil, "", errors.New("gemini api_key not configured")
|
||
}
|
||
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
if baseURL == "" {
|
||
baseURL = geminicli.AIStudioBaseURL
|
||
}
|
||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||
if useUpstreamStream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("x-goog-api-key", apiKey)
|
||
return upstreamReq, "x-request-id", nil
|
||
}
|
||
requestIDHeader = "x-request-id"
|
||
|
||
case AccountTypeOAuth:
|
||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||
if s.tokenProvider == nil {
|
||
return nil, "", errors.New("gemini token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// Two modes for OAuth:
|
||
// 1. With project_id -> Code Assist API (wrapped request)
|
||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||
if projectID != "" && !forceAIStudio {
|
||
// Mode 1: Code Assist API
|
||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
|
||
if useUpstreamStream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
wrapped := map[string]any{
|
||
"model": mappedModel,
|
||
"project": projectID,
|
||
}
|
||
var inner any
|
||
if err := json.Unmarshal(body, &inner); err != nil {
|
||
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
|
||
}
|
||
wrapped["request"] = inner
|
||
wrappedBytes, _ := json.Marshal(wrapped)
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||
return upstreamReq, "x-request-id", nil
|
||
} else {
|
||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
if baseURL == "" {
|
||
baseURL = geminicli.AIStudioBaseURL
|
||
}
|
||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||
if useUpstreamStream {
|
||
fullURL += "?alt=sse"
|
||
}
|
||
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||
return upstreamReq, "x-request-id", nil
|
||
}
|
||
}
|
||
requestIDHeader = "x-request-id"
|
||
|
||
default:
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
|
||
}
|
||
|
||
var resp *http.Response
|
||
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
|
||
upstreamReq, idHeader, err := buildReq(ctx)
|
||
if err != nil {
|
||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||
return nil, err
|
||
}
|
||
// Local build error: don't retry.
|
||
if strings.Contains(err.Error(), "missing project_id") {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error())
|
||
}
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error())
|
||
}
|
||
requestIDHeader = idHeader
|
||
|
||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
if attempt < geminiMaxRetries {
|
||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||
sleepGeminiBackoff(attempt)
|
||
continue
|
||
}
|
||
if action == "countTokens" {
|
||
estimated := estimateGeminiCountTokens(body)
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||
return &ForwardResult{
|
||
RequestID: "",
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
}
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
||
}
|
||
|
||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
// Don't treat insufficient-scope as transient.
|
||
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break
|
||
}
|
||
if resp.StatusCode == 429 {
|
||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||
}
|
||
if attempt < geminiMaxRetries {
|
||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||
sleepGeminiBackoff(attempt)
|
||
continue
|
||
}
|
||
if action == "countTokens" {
|
||
estimated := estimateGeminiCountTokens(body)
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||
return &ForwardResult{
|
||
RequestID: "",
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
}
|
||
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break
|
||
}
|
||
|
||
break
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
requestID := resp.Header.Get(requestIDHeader)
|
||
if requestID == "" {
|
||
requestID = resp.Header.Get("x-goog-request-id")
|
||
}
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
isOAuth := account.Type == AccountTypeOAuth
|
||
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
tempMatched := false
|
||
if s.rateLimitService != nil {
|
||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||
}
|
||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||
|
||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||
// This avoids Gemini SDKs failing hard during preflight token counting.
|
||
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
|
||
estimated := estimateGeminiCountTokens(body)
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
}
|
||
|
||
if tempMatched {
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||
}
|
||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||
}
|
||
|
||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
c.Data(resp.StatusCode, contentType, respBody)
|
||
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
|
||
if stream {
|
||
streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
} else {
|
||
if useUpstreamStream {
|
||
collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream")
|
||
}
|
||
b, _ := json.Marshal(collected)
|
||
c.Data(http.StatusOK, "application/json", b)
|
||
usage = usageObj
|
||
} else {
|
||
usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
usage = usageResp
|
||
}
|
||
}
|
||
|
||
if usage == nil {
|
||
usage = &ClaudeUsage{}
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel,
|
||
Stream: stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
}, nil
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||
switch statusCode {
|
||
case 429, 500, 502, 503, 504, 529:
|
||
return true
|
||
case 403:
|
||
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
|
||
if account == nil || account.Type != AccountTypeOAuth {
|
||
return false
|
||
}
|
||
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
|
||
if oauthType == "" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||
// Legacy/implicit Code Assist OAuth accounts.
|
||
oauthType = "code_assist"
|
||
}
|
||
return oauthType == "code_assist"
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
|
||
switch statusCode {
|
||
case 401, 403, 429, 529:
|
||
return true
|
||
default:
|
||
return statusCode >= 500
|
||
}
|
||
}
|
||
|
||
func sleepGeminiBackoff(attempt int) {
|
||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||
if delay > geminiRetryMaxDelay {
|
||
delay = geminiRetryMaxDelay
|
||
}
|
||
|
||
// +/- 20% jitter
|
||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
|
||
sleepFor := delay + jitter
|
||
if sleepFor < 0 {
|
||
sleepFor = 0
|
||
}
|
||
time.Sleep(sleepFor)
|
||
}
|
||
|
||
// sensitiveQueryParamRegex matches, case-insensitively, a key, client_secret,
// access_token or refresh_token query parameter. Capture group 1 holds the
// "?name=" / "&name=" prefix; the value runs until the next '&', '"' or
// whitespace.
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)

// retryInRegex captures the number of seconds from a
// "Please retry in <seconds>s" message.
var retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
|
||
|
||
func sanitizeUpstreamErrorMessage(msg string) string {
|
||
if msg == "" {
|
||
return msg
|
||
}
|
||
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||
var statusCode int
|
||
var errType, errMsg string
|
||
|
||
if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
|
||
errType = mapped.Type
|
||
if mapped.Message != "" {
|
||
errMsg = mapped.Message
|
||
}
|
||
if mapped.StatusCode > 0 {
|
||
statusCode = mapped.StatusCode
|
||
}
|
||
}
|
||
|
||
switch upstreamStatus {
|
||
case 400:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusBadRequest
|
||
}
|
||
if errType == "" {
|
||
errType = "invalid_request_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Invalid request"
|
||
}
|
||
case 401:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusBadGateway
|
||
}
|
||
if errType == "" {
|
||
errType = "authentication_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream authentication failed, please contact administrator"
|
||
}
|
||
case 403:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusBadGateway
|
||
}
|
||
if errType == "" {
|
||
errType = "permission_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream access forbidden, please contact administrator"
|
||
}
|
||
case 404:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusNotFound
|
||
}
|
||
if errType == "" {
|
||
errType = "not_found_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Resource not found"
|
||
}
|
||
case 429:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusTooManyRequests
|
||
}
|
||
if errType == "" {
|
||
errType = "rate_limit_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream rate limit exceeded, please retry later"
|
||
}
|
||
case 529:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusServiceUnavailable
|
||
}
|
||
if errType == "" {
|
||
errType = "overloaded_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream service overloaded, please retry later"
|
||
}
|
||
case 500, 502, 503, 504:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusBadGateway
|
||
}
|
||
if errType == "" {
|
||
switch upstreamStatus {
|
||
case 504:
|
||
errType = "timeout_error"
|
||
case 503:
|
||
errType = "overloaded_error"
|
||
default:
|
||
errType = "api_error"
|
||
}
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream service temporarily unavailable"
|
||
}
|
||
default:
|
||
if statusCode == 0 {
|
||
statusCode = http.StatusBadGateway
|
||
}
|
||
if errType == "" {
|
||
errType = "upstream_error"
|
||
}
|
||
if errMsg == "" {
|
||
errMsg = "Upstream request failed"
|
||
}
|
||
}
|
||
|
||
c.JSON(statusCode, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": errMsg},
|
||
})
|
||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||
}
|
||
|
||
// claudeErrorMapping is the Claude-style view of a Gemini error body.
type claudeErrorMapping struct {
	// Type is the Claude error type, e.g. "invalid_request_error".
	Type string
	// Message is an optional client-facing message; empty means the caller
	// should pick a default.
	Message string
	// StatusCode is the HTTP status to answer with; zero means the caller
	// should decide.
	StatusCode int
}
|
||
|
||
func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping {
|
||
if len(body) == 0 {
|
||
return nil
|
||
}
|
||
|
||
var parsed struct {
|
||
Error struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
Status string `json:"status"`
|
||
} `json:"error"`
|
||
}
|
||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||
return nil
|
||
}
|
||
if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" {
|
||
return nil
|
||
}
|
||
|
||
mapped := &claudeErrorMapping{
|
||
Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status),
|
||
Message: "",
|
||
}
|
||
if mapped.Type == "" {
|
||
mapped.Type = "upstream_error"
|
||
}
|
||
|
||
switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) {
|
||
case "INVALID_ARGUMENT":
|
||
mapped.StatusCode = http.StatusBadRequest
|
||
case "NOT_FOUND":
|
||
mapped.StatusCode = http.StatusNotFound
|
||
case "RESOURCE_EXHAUSTED":
|
||
mapped.StatusCode = http.StatusTooManyRequests
|
||
default:
|
||
// Keep StatusCode unset and let HTTP status mapping decide.
|
||
}
|
||
|
||
// Keep messages generic by default; upstream error message can be long or include sensitive fragments.
|
||
return mapped
|
||
}
|
||
|
||
// mapGeminiStatusToClaudeErrorType translates a Google API status string
// (matched case-insensitively, ignoring surrounding whitespace) into the
// corresponding Claude error type. Unrecognized statuses yield "".
func mapGeminiStatusToClaudeErrorType(status string) string {
	errorTypes := map[string]string{
		"INVALID_ARGUMENT":   "invalid_request_error",
		"PERMISSION_DENIED":  "permission_error",
		"NOT_FOUND":          "not_found_error",
		"RESOURCE_EXHAUSTED": "rate_limit_error",
		"UNAUTHENTICATED":    "authentication_error",
		"UNAVAILABLE":        "overloaded_error",
		"INTERNAL":           "api_error",
		"DEADLINE_EXCEEDED":  "timeout_error",
	}
	// A missing key yields the zero value "", matching the unknown case.
	return errorTypes[strings.ToUpper(strings.TrimSpace(status))]
}
|
||
|
||
type geminiStreamResult struct {
|
||
usage *ClaudeUsage
|
||
firstTokenMs *int
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
|
||
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
|
||
}
|
||
|
||
geminiResp, err := unwrapGeminiResponse(body)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||
}
|
||
|
||
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel)
|
||
c.JSON(http.StatusOK, claudeResp)
|
||
|
||
return usage, nil
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
c.Status(http.StatusOK)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
messageID := "msg_" + randomHex(12)
|
||
messageStart := map[string]any{
|
||
"type": "message_start",
|
||
"message": map[string]any{
|
||
"id": messageID,
|
||
"type": "message",
|
||
"role": "assistant",
|
||
"model": originalModel,
|
||
"content": []any{},
|
||
"stop_reason": nil,
|
||
"stop_sequence": nil,
|
||
"usage": map[string]any{
|
||
"input_tokens": 0,
|
||
"output_tokens": 0,
|
||
},
|
||
},
|
||
}
|
||
writeSSE(c.Writer, "message_start", messageStart)
|
||
flusher.Flush()
|
||
|
||
var firstTokenMs *int
|
||
var usage ClaudeUsage
|
||
finishReason := ""
|
||
sawToolUse := false
|
||
|
||
nextBlockIndex := 0
|
||
openBlockIndex := -1
|
||
openBlockType := ""
|
||
seenText := ""
|
||
openToolIndex := -1
|
||
openToolID := ""
|
||
openToolName := ""
|
||
seenToolJSON := ""
|
||
|
||
reader := bufio.NewReader(resp.Body)
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
if err != nil && !errors.Is(err, io.EOF) {
|
||
return nil, fmt.Errorf("stream read error: %w", err)
|
||
}
|
||
|
||
if !strings.HasPrefix(line, "data:") {
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
continue
|
||
}
|
||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
continue
|
||
}
|
||
|
||
geminiResp, err := unwrapGeminiResponse([]byte(payload))
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
|
||
finishReason = fr
|
||
}
|
||
|
||
parts := extractGeminiParts(geminiResp)
|
||
for _, part := range parts {
|
||
if text, ok := part["text"].(string); ok && text != "" {
|
||
delta, newSeen := computeGeminiTextDelta(seenText, text)
|
||
seenText = newSeen
|
||
if delta == "" {
|
||
continue
|
||
}
|
||
|
||
if openBlockType != "text" {
|
||
if openBlockIndex >= 0 {
|
||
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||
"type": "content_block_stop",
|
||
"index": openBlockIndex,
|
||
})
|
||
}
|
||
openBlockType = "text"
|
||
openBlockIndex = nextBlockIndex
|
||
nextBlockIndex++
|
||
writeSSE(c.Writer, "content_block_start", map[string]any{
|
||
"type": "content_block_start",
|
||
"index": openBlockIndex,
|
||
"content_block": map[string]any{
|
||
"type": "text",
|
||
"text": "",
|
||
},
|
||
})
|
||
}
|
||
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
writeSSE(c.Writer, "content_block_delta", map[string]any{
|
||
"type": "content_block_delta",
|
||
"index": openBlockIndex,
|
||
"delta": map[string]any{
|
||
"type": "text_delta",
|
||
"text": delta,
|
||
},
|
||
})
|
||
flusher.Flush()
|
||
continue
|
||
}
|
||
|
||
if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
|
||
name, _ := fc["name"].(string)
|
||
args := fc["args"]
|
||
if strings.TrimSpace(name) == "" {
|
||
name = "tool"
|
||
}
|
||
|
||
// Close any open text block before tool_use.
|
||
if openBlockIndex >= 0 {
|
||
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||
"type": "content_block_stop",
|
||
"index": openBlockIndex,
|
||
})
|
||
openBlockIndex = -1
|
||
openBlockType = ""
|
||
}
|
||
|
||
// If we receive streamed tool args in pieces, keep a single tool block open and emit deltas.
|
||
if openToolIndex >= 0 && openToolName != name {
|
||
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||
"type": "content_block_stop",
|
||
"index": openToolIndex,
|
||
})
|
||
openToolIndex = -1
|
||
openToolName = ""
|
||
seenToolJSON = ""
|
||
}
|
||
|
||
if openToolIndex < 0 {
|
||
openToolID = "toolu_" + randomHex(8)
|
||
openToolIndex = nextBlockIndex
|
||
openToolName = name
|
||
nextBlockIndex++
|
||
sawToolUse = true
|
||
|
||
writeSSE(c.Writer, "content_block_start", map[string]any{
|
||
"type": "content_block_start",
|
||
"index": openToolIndex,
|
||
"content_block": map[string]any{
|
||
"type": "tool_use",
|
||
"id": openToolID,
|
||
"name": name,
|
||
"input": map[string]any{},
|
||
},
|
||
})
|
||
}
|
||
|
||
argsJSONText := "{}"
|
||
switch v := args.(type) {
|
||
case nil:
|
||
// keep default "{}"
|
||
case string:
|
||
if strings.TrimSpace(v) != "" {
|
||
argsJSONText = v
|
||
}
|
||
default:
|
||
if b, err := json.Marshal(args); err == nil && len(b) > 0 {
|
||
argsJSONText = string(b)
|
||
}
|
||
}
|
||
|
||
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
|
||
seenToolJSON = newSeen
|
||
if delta != "" {
|
||
writeSSE(c.Writer, "content_block_delta", map[string]any{
|
||
"type": "content_block_delta",
|
||
"index": openToolIndex,
|
||
"delta": map[string]any{
|
||
"type": "input_json_delta",
|
||
"partial_json": delta,
|
||
},
|
||
})
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
|
||
if u := extractGeminiUsage(geminiResp); u != nil {
|
||
usage = *u
|
||
}
|
||
|
||
// Process the final unterminated line at EOF as well.
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
}
|
||
|
||
if openBlockIndex >= 0 {
|
||
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||
"type": "content_block_stop",
|
||
"index": openBlockIndex,
|
||
})
|
||
}
|
||
if openToolIndex >= 0 {
|
||
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||
"type": "content_block_stop",
|
||
"index": openToolIndex,
|
||
})
|
||
}
|
||
|
||
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
|
||
if sawToolUse {
|
||
stopReason = "tool_use"
|
||
}
|
||
|
||
usageObj := map[string]any{
|
||
"output_tokens": usage.OutputTokens,
|
||
}
|
||
if usage.InputTokens > 0 {
|
||
usageObj["input_tokens"] = usage.InputTokens
|
||
}
|
||
writeSSE(c.Writer, "message_delta", map[string]any{
|
||
"type": "message_delta",
|
||
"delta": map[string]any{
|
||
"stop_reason": stopReason,
|
||
"stop_sequence": nil,
|
||
},
|
||
"usage": usageObj,
|
||
})
|
||
writeSSE(c.Writer, "message_stop", map[string]any{
|
||
"type": "message_stop",
|
||
})
|
||
flusher.Flush()
|
||
|
||
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// writeSSE writes one server-sent event to w: an optional "event:" line
// (omitted when event is empty), then a "data:" line holding data encoded as
// JSON, then a blank line. Marshal and write errors are ignored.
func writeSSE(w io.Writer, event string, data any) {
	if len(event) > 0 {
		_, _ = fmt.Fprintf(w, "event: %s\n", event)
	}
	encoded, _ := json.Marshal(data)
	_, _ = fmt.Fprintf(w, "data: %s\n\n", encoded)
}
|
||
|
||
// randomHex returns nBytes of crypto/rand output encoded as 2*nBytes
// lowercase hex characters. The error from rand.Read is ignored.
func randomHex(nBytes int) string {
	buf := make([]byte, nBytes)
	_, _ = rand.Read(buf)
	encoded := hex.EncodeToString(buf)
	return encoded
}
|
||
|
||
func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
|
||
c.JSON(status, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": message},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||
c.JSON(status, gin.H{
|
||
"error": gin.H{
|
||
"code": status,
|
||
"message": message,
|
||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||
},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
|
||
if !isOAuth {
|
||
return raw
|
||
}
|
||
inner, err := unwrapGeminiResponse(raw)
|
||
if err != nil {
|
||
return raw
|
||
}
|
||
b, err := json.Marshal(inner)
|
||
if err != nil {
|
||
return raw
|
||
}
|
||
return b
|
||
}
|
||
|
||
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
|
||
reader := bufio.NewReader(body)
|
||
|
||
var last map[string]any
|
||
var lastWithParts map[string]any
|
||
usage := &ClaudeUsage{}
|
||
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
if len(line) > 0 {
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
if strings.HasPrefix(trimmed, "data:") {
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
switch payload {
|
||
case "", "[DONE]":
|
||
if payload == "[DONE]" {
|
||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||
}
|
||
default:
|
||
var parsed map[string]any
|
||
if isOAuth {
|
||
inner, err := unwrapGeminiResponse([]byte(payload))
|
||
if err == nil && inner != nil {
|
||
parsed = inner
|
||
}
|
||
} else {
|
||
_ = json.Unmarshal([]byte(payload), &parsed)
|
||
}
|
||
if parsed != nil {
|
||
last = parsed
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||
lastWithParts = parsed
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
}
|
||
|
||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||
}
|
||
|
||
// pickGeminiCollectResult chooses the object to return from a collected
// stream: lastWithParts if set, otherwise last, otherwise an empty non-nil
// map.
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
	switch {
	case lastWithParts != nil:
		return lastWithParts
	case last != nil:
		return last
	default:
		return map[string]any{}
	}
}
|
||
|
||
type geminiNativeStreamResult struct {
|
||
usage *ClaudeUsage
|
||
firstTokenMs *int
|
||
}
|
||
|
||
// isGeminiInsufficientScope reports whether an upstream error signals
// missing OAuth scopes. It matches, case-insensitively, "insufficient_scope"
// in the Www-Authenticate header, or "insufficient authentication scopes" /
// "access_token_scope_insufficient" in the body.
func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
	challenge := strings.ToLower(headers.Get("Www-Authenticate"))
	if strings.Contains(challenge, "insufficient_scope") {
		return true
	}

	text := strings.ToLower(string(body))
	for _, marker := range []string{
		"insufficient authentication scopes",
		"access_token_scope_insufficient",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|
||
|
||
func estimateGeminiCountTokens(reqBody []byte) int {
|
||
var obj map[string]any
|
||
if err := json.Unmarshal(reqBody, &obj); err != nil {
|
||
return 0
|
||
}
|
||
|
||
var texts []string
|
||
|
||
// systemInstruction.parts[].text
|
||
if si, ok := obj["systemInstruction"].(map[string]any); ok {
|
||
if parts, ok := si["parts"].([]any); ok {
|
||
for _, p := range parts {
|
||
if pm, ok := p.(map[string]any); ok {
|
||
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
|
||
texts = append(texts, t)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// contents[].parts[].text
|
||
if contents, ok := obj["contents"].([]any); ok {
|
||
for _, c := range contents {
|
||
cm, ok := c.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
parts, ok := cm["parts"].([]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
for _, p := range parts {
|
||
pm, ok := p.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
|
||
texts = append(texts, t)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
total := 0
|
||
for _, t := range texts {
|
||
total += estimateTokensForText(t)
|
||
}
|
||
if total < 0 {
|
||
return 0
|
||
}
|
||
return total
|
||
}
|
||
|
||
// estimateTokensForText roughly estimates the token count of s after
// trimming surrounding whitespace.
//
// When at least 80% of the runes are ASCII, the estimate is one token per
// four runes, rounded up. Otherwise (for example CJK-heavy text) it is one
// token per rune. Blank input yields 0.
func estimateTokensForText(s string) int {
	trimmed := strings.TrimSpace(s)

	total, ascii := 0, 0
	for _, r := range trimmed {
		total++
		if r <= 0x7f {
			ascii++
		}
	}
	if total == 0 {
		return 0
	}

	if float64(ascii)/float64(total) >= 0.8 {
		// Roughly 4 chars per token for English-like text.
		return (total + 3) / 4
	}
	// For CJK-heavy text, approximate 1 rune per token.
	return total
}
|
||
|
||
// UpstreamHTTPResult is a fully buffered upstream HTTP response.
type UpstreamHTTPResult struct {
	// StatusCode is the upstream HTTP status code.
	StatusCode int
	// Headers are the response headers to pass on to the client.
	Headers http.Header
	// Body is the response body as read from upstream.
	Body []byte
}
|
||
|
||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||
// Log response headers for debugging
|
||
log.Printf("[GeminiAPI] ========== Response Headers ==========")
|
||
for key, values := range resp.Header {
|
||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||
}
|
||
}
|
||
log.Printf("[GeminiAPI] ========================================")
|
||
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var parsed map[string]any
|
||
if isOAuth {
|
||
parsed, err = unwrapGeminiResponse(respBody)
|
||
if err == nil && parsed != nil {
|
||
respBody, _ = json.Marshal(parsed)
|
||
}
|
||
} else {
|
||
_ = json.Unmarshal(respBody, &parsed)
|
||
}
|
||
|
||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
c.Data(resp.StatusCode, contentType, respBody)
|
||
|
||
if parsed != nil {
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
return u, nil
|
||
}
|
||
}
|
||
return &ClaudeUsage{}, nil
|
||
}
|
||
|
||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||
// Log response headers for debugging
|
||
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
|
||
for key, values := range resp.Header {
|
||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||
}
|
||
}
|
||
log.Printf("[GeminiAPI] ====================================================")
|
||
|
||
if s.cfg != nil {
|
||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||
}
|
||
|
||
c.Status(resp.StatusCode)
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "text/event-stream; charset=utf-8"
|
||
}
|
||
c.Header("Content-Type", contentType)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
reader := bufio.NewReader(resp.Body)
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
if len(line) > 0 {
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
if strings.HasPrefix(trimmed, "data:") {
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
// Keepalive / done markers
|
||
if payload == "" || payload == "[DONE]" {
|
||
_, _ = io.WriteString(c.Writer, line)
|
||
flusher.Flush()
|
||
} else {
|
||
var rawToWrite string
|
||
rawToWrite = payload
|
||
|
||
var parsed map[string]any
|
||
if isOAuth {
|
||
inner, err := unwrapGeminiResponse([]byte(payload))
|
||
if err == nil && inner != nil {
|
||
parsed = inner
|
||
if b, err := json.Marshal(inner); err == nil {
|
||
rawToWrite = string(b)
|
||
}
|
||
}
|
||
} else {
|
||
_ = json.Unmarshal([]byte(payload), &parsed)
|
||
}
|
||
|
||
if parsed != nil {
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
}
|
||
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
if isOAuth {
|
||
// SSE format requires double newline (\n\n) to separate events
|
||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite)
|
||
} else {
|
||
// Pass-through for AI Studio responses.
|
||
_, _ = io.WriteString(c.Writer, line)
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
} else {
|
||
_, _ = io.WriteString(c.Writer, line)
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for
|
||
// endpoints like /v1beta/models and /v1beta/models/{model}.
|
||
//
|
||
// This is used to support Gemini SDKs that call models listing endpoints before generation.
|
||
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
|
||
if account == nil {
|
||
return nil, errors.New("account is nil")
|
||
}
|
||
path = strings.TrimSpace(path)
|
||
if path == "" || !strings.HasPrefix(path, "/") {
|
||
return nil, errors.New("invalid path")
|
||
}
|
||
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
if baseURL == "" {
|
||
baseURL = geminicli.AIStudioBaseURL
|
||
}
|
||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
|
||
|
||
var proxyURL string
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
switch account.Type {
|
||
case AccountTypeAPIKey:
|
||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||
if apiKey == "" {
|
||
return nil, errors.New("gemini api_key not configured")
|
||
}
|
||
req.Header.Set("x-goog-api-key", apiKey)
|
||
case AccountTypeOAuth:
|
||
if s.tokenProvider == nil {
|
||
return nil, errors.New("gemini token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||
default:
|
||
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
|
||
}
|
||
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||
if wwwAuthenticate != "" {
|
||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||
}
|
||
return &UpstreamHTTPResult{
|
||
StatusCode: resp.StatusCode,
|
||
Headers: filteredHeaders,
|
||
Body: body,
|
||
}, nil
|
||
}
|
||
|
||
// unwrapGeminiResponse decodes raw as a JSON object. If that object has a
// non-nil object under the "response" key, the inner object is returned;
// otherwise the decoded object itself is returned. Invalid JSON yields an
// error.
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
	var envelope map[string]any
	if err := json.Unmarshal(raw, &envelope); err != nil {
		return nil, err
	}

	inner, wrapped := envelope["response"].(map[string]any)
	if !wrapped || inner == nil {
		return envelope, nil
	}
	return inner, nil
}
|
||
|
||
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
|
||
usage := extractGeminiUsage(geminiResp)
|
||
if usage == nil {
|
||
usage = &ClaudeUsage{}
|
||
}
|
||
|
||
contentBlocks := make([]any, 0)
|
||
sawToolUse := false
|
||
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
|
||
if cand, ok := candidates[0].(map[string]any); ok {
|
||
if content, ok := cand["content"].(map[string]any); ok {
|
||
if parts, ok := content["parts"].([]any); ok {
|
||
for _, part := range parts {
|
||
pm, ok := part.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
if text, ok := pm["text"].(string); ok && text != "" {
|
||
contentBlocks = append(contentBlocks, map[string]any{
|
||
"type": "text",
|
||
"text": text,
|
||
})
|
||
}
|
||
if fc, ok := pm["functionCall"].(map[string]any); ok {
|
||
name, _ := fc["name"].(string)
|
||
if strings.TrimSpace(name) == "" {
|
||
name = "tool"
|
||
}
|
||
args := fc["args"]
|
||
sawToolUse = true
|
||
contentBlocks = append(contentBlocks, map[string]any{
|
||
"type": "tool_use",
|
||
"id": "toolu_" + randomHex(8),
|
||
"name": name,
|
||
"input": args,
|
||
})
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp))
|
||
if sawToolUse {
|
||
stopReason = "tool_use"
|
||
}
|
||
|
||
resp := map[string]any{
|
||
"id": "msg_" + randomHex(12),
|
||
"type": "message",
|
||
"role": "assistant",
|
||
"model": originalModel,
|
||
"content": contentBlocks,
|
||
"stop_reason": stopReason,
|
||
"stop_sequence": nil,
|
||
"usage": map[string]any{
|
||
"input_tokens": usage.InputTokens,
|
||
"output_tokens": usage.OutputTokens,
|
||
},
|
||
}
|
||
|
||
return resp, usage
|
||
}
|
||
|
||
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
||
usageMeta, ok := geminiResp["usageMetadata"].(map[string]any)
|
||
if !ok || usageMeta == nil {
|
||
return nil
|
||
}
|
||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
||
return &ClaudeUsage{
|
||
InputTokens: prompt,
|
||
OutputTokens: cand,
|
||
}
|
||
}
|
||
|
||
// asInt converts a loosely typed numeric value, such as one produced by JSON
// decoding, into an int.
//
// Accepted inputs are float64 (converted with int(), i.e. the fraction is
// discarded), int, int64 and json.Number values that parse as an int64. Every
// other type, and a json.Number that fails to parse, yields (0, false).
func asInt(v any) (int, bool) {
	if f, ok := v.(float64); ok {
		return int(f), true
	}
	if i, ok := v.(int); ok {
		return i, true
	}
	if i64, ok := v.(int64); ok {
		return int(i64), true
	}
	if num, ok := v.(json.Number); ok {
		parsed, err := num.Int64()
		if err != nil {
			return 0, false
		}
		return int(parsed), true
	}
	return 0, false
}
|
||
|
||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||
return
|
||
}
|
||
if statusCode != 429 {
|
||
return
|
||
}
|
||
|
||
oauthType := account.GeminiOAuthType()
|
||
tierID := account.GeminiTierID()
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
isCodeAssist := account.IsGeminiCodeAssist()
|
||
|
||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||
if resetAt == nil {
|
||
// 根据账号类型使用不同的默认重置时间
|
||
var ra time.Time
|
||
if isCodeAssist {
|
||
// Code Assist: fallback cooldown by tier
|
||
cooldown := geminiCooldownForTier(tierID)
|
||
if s.rateLimitService != nil {
|
||
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
|
||
}
|
||
ra = time.Now().Add(cooldown)
|
||
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
|
||
} else {
|
||
// API Key / AI Studio OAuth: PST 午夜
|
||
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||
ra = time.Unix(*ts, 0)
|
||
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
|
||
} else {
|
||
// 兜底:5 分钟
|
||
ra = time.Now().Add(5 * time.Minute)
|
||
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
|
||
}
|
||
}
|
||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||
return
|
||
}
|
||
|
||
// 使用解析到的重置时间
|
||
resetTime := time.Unix(*resetAt, 0)
|
||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
|
||
account.ID, resetTime, oauthType, tierID)
|
||
}
|
||
|
||
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||
// Try to parse metadata.quotaResetDelay like "12.345s"
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(body, &parsed); err == nil {
|
||
if errObj, ok := parsed["error"].(map[string]any); ok {
|
||
if msg, ok := errObj["message"].(string); ok {
|
||
if looksLikeGeminiDailyQuota(msg) {
|
||
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||
return ts
|
||
}
|
||
}
|
||
}
|
||
if details, ok := errObj["details"].([]any); ok {
|
||
for _, d := range details {
|
||
dm, ok := d.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
if meta, ok := dm["metadata"].(map[string]any); ok {
|
||
if v, ok := meta["quotaResetDelay"].(string); ok {
|
||
if dur, err := time.ParseDuration(v); err == nil {
|
||
ts := time.Now().Unix() + int64(dur.Seconds())
|
||
return &ts
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Match "Please retry in Xs"
|
||
matches := retryInRegex.FindStringSubmatch(string(body))
|
||
if len(matches) == 2 {
|
||
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
|
||
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
|
||
return &ts
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// looksLikeGeminiDailyQuota reports whether an upstream error message refers
// to a per-day quota.
//
// The match is case-insensitive and keys on the phrase "per day". Longer
// phrasings such as "requests per day", or a message mentioning both "quota"
// and "per day", necessarily contain that phrase, so one substring test
// covers them all.
func looksLikeGeminiDailyQuota(message string) bool {
	return strings.Contains(strings.ToLower(message), "per day")
}
|
||
|
||
func nextGeminiDailyResetUnix() *int64 {
|
||
reset := geminiDailyResetTime(time.Now())
|
||
ts := reset.Unix()
|
||
return &ts
|
||
}
|
||
|
||
// extractGeminiFinishReason returns the finishReason string of the first
// candidate in a Gemini response.
//
// It returns "" when there are no candidates, when the first candidate is not
// a JSON object, or when finishReason is missing or not a string.
func extractGeminiFinishReason(geminiResp map[string]any) string {
	candidates, ok := geminiResp["candidates"].([]any)
	if !ok || len(candidates) == 0 {
		return ""
	}
	first, ok := candidates[0].(map[string]any)
	if !ok {
		return ""
	}
	// A failed assertion leaves reason as "", which is the fallback value.
	reason, _ := first["finishReason"].(string)
	return reason
}
|
||
|
||
// extractGeminiParts returns the object-valued entries of
// candidates[0].content.parts from a Gemini response.
//
// It returns nil when any level of that path is missing, has the wrong type,
// or when parts is empty. When parts is non-empty, entries that are not JSON
// objects are skipped, so the result may be an empty (non-nil) slice.
func extractGeminiParts(geminiResp map[string]any) []map[string]any {
	candidates, ok := geminiResp["candidates"].([]any)
	if !ok || len(candidates) == 0 {
		return nil
	}
	first, ok := candidates[0].(map[string]any)
	if !ok {
		return nil
	}
	content, ok := first["content"].(map[string]any)
	if !ok {
		return nil
	}
	rawParts, ok := content["parts"].([]any)
	if !ok || len(rawParts) == 0 {
		return nil
	}

	parts := make([]map[string]any, 0, len(rawParts))
	for i := range rawParts {
		if obj, isObject := rawParts[i].(map[string]any); isObject {
			parts = append(parts, obj)
		}
	}
	return parts
}
|
||
|
||
// computeGeminiTextDelta works out which text is new in a streamed chunk.
//
// seen is the text accumulated so far and incoming is the text of the latest
// chunk (a single trailing NUL character is stripped first). It returns the
// text to emit (delta) and the updated accumulator (newSeen):
//
//   - empty chunk: nothing to emit, accumulator unchanged;
//   - chunk starts with seen: the chunk is treated as the full text so far,
//     only the suffix is emitted and the chunk becomes the accumulator;
//   - seen starts with chunk: duplicate or rewind, ignored;
//   - anything else: the chunk is treated as an increment and appended.
func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) {
	chunk := strings.TrimSuffix(incoming, "\u0000")

	switch {
	case chunk == "":
		return "", seen
	case strings.HasPrefix(chunk, seen):
		// Cumulative mode: slice off the part that was already emitted.
		return chunk[len(seen):], chunk
	case strings.HasPrefix(seen, chunk):
		return "", seen
	default:
		// Delta mode.
		return chunk, seen + chunk
	}
}
|
||
|
||
// mapGeminiFinishReasonToClaudeStopReason translates a Gemini finishReason
// into a Claude stop_reason.
//
// The input is trimmed and compared case-insensitively (via upper-casing).
// "MAX_TOKENS" maps to "max_tokens"; "STOP" and every other value, including
// the empty string, map to "end_turn".
func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string {
	normalized := strings.ToUpper(strings.TrimSpace(finishReason))
	if normalized == "MAX_TOKENS" {
		return "max_tokens"
	}
	return "end_turn"
}
|
||
|
||
func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) {
|
||
var req map[string]any
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
toolUseIDToName := make(map[string]string)
|
||
|
||
systemText := extractClaudeSystemText(req["system"])
|
||
contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
out := make(map[string]any)
|
||
if systemText != "" {
|
||
out["systemInstruction"] = map[string]any{
|
||
"parts": []any{map[string]any{"text": systemText}},
|
||
}
|
||
}
|
||
out["contents"] = contents
|
||
|
||
if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil {
|
||
out["tools"] = tools
|
||
}
|
||
|
||
generationConfig := convertClaudeGenerationConfig(req)
|
||
if generationConfig != nil {
|
||
out["generationConfig"] = generationConfig
|
||
}
|
||
|
||
stripGeminiFunctionIDs(out)
|
||
return json.Marshal(out)
|
||
}
|
||
|
||
// stripGeminiFunctionIDs removes the "id" key from every functionCall and
// functionResponse object found under req["contents"][].parts[], in place.
//
// This is defensive cleanup: some upstreams reject unexpected `id` fields in
// functionCall/functionResponse. Entries that do not have the expected shape
// are left untouched.
func stripGeminiFunctionIDs(req map[string]any) {
	// A failed assertion yields a nil slice, and ranging over nil is a no-op.
	contents, _ := req["contents"].([]any)
	for _, rawContent := range contents {
		content, ok := rawContent.(map[string]any)
		if !ok {
			continue
		}
		parts, _ := content["parts"].([]any)
		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]any)
			if !ok {
				continue
			}
			for _, field := range [...]string{"functionCall", "functionResponse"} {
				// delete on a nil map is a no-op, so no nil check is needed.
				if payload, ok := part[field].(map[string]any); ok {
					delete(payload, "id")
				}
			}
		}
	}
}
|
||
|
||
// extractClaudeSystemText flattens the Claude "system" field into one string.
//
// A string value is returned trimmed. For an array, the text of every block
// whose type is "text" and whose text is not blank is collected, joined with
// newlines, and the joined result is trimmed. Any other value yields "".
func extractClaudeSystemText(system any) string {
	if s, ok := system.(string); ok {
		return strings.TrimSpace(s)
	}
	blocks, ok := system.([]any)
	if !ok {
		return ""
	}

	texts := make([]string, 0, len(blocks))
	for _, raw := range blocks {
		block, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		kind, _ := block["type"].(string)
		text, isString := block["text"].(string)
		if kind != "text" || !isString || strings.TrimSpace(text) == "" {
			continue
		}
		// Each piece is kept verbatim; only the joined result is trimmed.
		texts = append(texts, text)
	}
	return strings.TrimSpace(strings.Join(texts, "\n"))
}
|
||
|
||
func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) {
|
||
arr, ok := messages.([]any)
|
||
if !ok {
|
||
return nil, errors.New("messages must be an array")
|
||
}
|
||
|
||
out := make([]any, 0, len(arr))
|
||
for _, m := range arr {
|
||
mm, ok := m.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
role, _ := mm["role"].(string)
|
||
role = strings.ToLower(strings.TrimSpace(role))
|
||
gRole := "user"
|
||
if role == "assistant" {
|
||
gRole = "model"
|
||
}
|
||
|
||
parts := make([]any, 0)
|
||
switch content := mm["content"].(type) {
|
||
case string:
|
||
// 字符串形式的 content,保留所有内容(包括空白)
|
||
parts = append(parts, map[string]any{"text": content})
|
||
case []any:
|
||
// 如果只有一个 block,不过滤空白(让上游 API 报错)
|
||
singleBlock := len(content) == 1
|
||
|
||
for _, block := range content {
|
||
bm, ok := block.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
bt, _ := bm["type"].(string)
|
||
switch bt {
|
||
case "text":
|
||
if text, ok := bm["text"].(string); ok {
|
||
// 单个 block 时保留所有内容(包括空白)
|
||
// 多个 blocks 时过滤掉空白
|
||
if singleBlock || strings.TrimSpace(text) != "" {
|
||
parts = append(parts, map[string]any{"text": text})
|
||
}
|
||
}
|
||
case "tool_use":
|
||
id, _ := bm["id"].(string)
|
||
name, _ := bm["name"].(string)
|
||
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
|
||
toolUseIDToName[id] = name
|
||
}
|
||
parts = append(parts, map[string]any{
|
||
"functionCall": map[string]any{
|
||
"name": name,
|
||
"args": bm["input"],
|
||
},
|
||
})
|
||
case "tool_result":
|
||
toolUseID, _ := bm["tool_use_id"].(string)
|
||
name := toolUseIDToName[toolUseID]
|
||
if name == "" {
|
||
name = "tool"
|
||
}
|
||
parts = append(parts, map[string]any{
|
||
"functionResponse": map[string]any{
|
||
"name": name,
|
||
"response": map[string]any{
|
||
"content": extractClaudeContentText(bm["content"]),
|
||
},
|
||
},
|
||
})
|
||
case "image":
|
||
if src, ok := bm["source"].(map[string]any); ok {
|
||
if srcType, _ := src["type"].(string); srcType == "base64" {
|
||
mediaType, _ := src["media_type"].(string)
|
||
data, _ := src["data"].(string)
|
||
if mediaType != "" && data != "" {
|
||
parts = append(parts, map[string]any{
|
||
"inlineData": map[string]any{
|
||
"mimeType": mediaType,
|
||
"data": data,
|
||
},
|
||
})
|
||
}
|
||
}
|
||
}
|
||
default:
|
||
// best-effort: preserve unknown blocks as text
|
||
if b, err := json.Marshal(bm); err == nil {
|
||
parts = append(parts, map[string]any{"text": string(b)})
|
||
}
|
||
}
|
||
}
|
||
default:
|
||
// ignore
|
||
}
|
||
|
||
out = append(out, map[string]any{
|
||
"role": gRole,
|
||
"parts": parts,
|
||
})
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
// extractClaudeContentText reduces a Claude content value to plain text.
//
// A string is returned unchanged. For an array, the text of every block whose
// type is "text" is concatenated without a separator; other entries are
// skipped. Any other value, including nil, is returned as its JSON encoding
// (nil becomes "null").
func extractClaudeContentText(v any) string {
	if s, ok := v.(string); ok {
		return s
	}

	blocks, ok := v.([]any)
	if !ok {
		// The marshalling error is ignored on purpose: on failure the
		// encoded bytes are empty and the result is "".
		encoded, _ := json.Marshal(v)
		return string(encoded)
	}

	texts := make([]string, 0, len(blocks))
	for _, raw := range blocks {
		block, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if kind, _ := block["type"].(string); kind != "text" {
			continue
		}
		if text, ok := block["text"].(string); ok {
			texts = append(texts, text)
		}
	}
	return strings.Join(texts, "")
}
|
||
|
||
func convertClaudeToolsToGeminiTools(tools any) []any {
|
||
arr, ok := tools.([]any)
|
||
if !ok || len(arr) == 0 {
|
||
return nil
|
||
}
|
||
|
||
funcDecls := make([]any, 0, len(arr))
|
||
for _, t := range arr {
|
||
tm, ok := t.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
var name, desc string
|
||
var params any
|
||
|
||
// 检查是否为 custom 类型工具 (MCP)
|
||
toolType, _ := tm["type"].(string)
|
||
if toolType == "custom" {
|
||
// Custom 格式: 从 custom 字段获取 description 和 input_schema
|
||
custom, ok := tm["custom"].(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
name, _ = tm["name"].(string)
|
||
desc, _ = custom["description"].(string)
|
||
params = custom["input_schema"]
|
||
} else {
|
||
// 标准格式: 从顶层字段获取
|
||
name, _ = tm["name"].(string)
|
||
desc, _ = tm["description"].(string)
|
||
params = tm["input_schema"]
|
||
}
|
||
|
||
if name == "" {
|
||
continue
|
||
}
|
||
|
||
// 为 nil params 提供默认值
|
||
if params == nil {
|
||
params = map[string]any{
|
||
"type": "object",
|
||
"properties": map[string]any{},
|
||
}
|
||
}
|
||
// 清理 JSON Schema
|
||
cleanedParams := cleanToolSchema(params)
|
||
|
||
funcDecls = append(funcDecls, map[string]any{
|
||
"name": name,
|
||
"description": desc,
|
||
"parameters": cleanedParams,
|
||
})
|
||
}
|
||
|
||
if len(funcDecls) == 0 {
|
||
return nil
|
||
}
|
||
return []any{
|
||
map[string]any{
|
||
"functionDeclarations": funcDecls,
|
||
},
|
||
}
|
||
}
|
||
|
||
// cleanToolSchema returns a cleaned copy of a tool's JSON Schema: the
// keywords this converter treats as unsupported by Gemini are removed and
// every string "type" value is upper-cased (e.g. "object" -> "OBJECT").
//
// Removed keywords: $schema, $id, $ref, additionalProperties, minLength,
// maxLength, minItems and maxItems.
//
// The filter applies to schema keywords only. The keys of a "properties"
// object are caller-chosen parameter names, not keywords, so a parameter that
// happens to be called e.g. "maxLength" is kept and only its own sub-schema
// is cleaned. Dropping it would leave "required" pointing at a parameter that
// no longer exists.
//
// Maps and slices are rebuilt recursively and the input is never mutated;
// scalars are returned unchanged and a nil schema yields nil.
func cleanToolSchema(schema any) any {
	if schema == nil {
		return nil
	}

	switch v := schema.(type) {
	case map[string]any:
		cleaned := make(map[string]any, len(v))
		for key, value := range v {
			switch key {
			case "$schema", "$id", "$ref", "additionalProperties",
				"minLength", "maxLength", "minItems", "maxItems":
				// Unsupported keyword: skip it.
				continue
			case "properties":
				if props, ok := value.(map[string]any); ok {
					// Keep every parameter name; clean only its schema.
					cleanedProps := make(map[string]any, len(props))
					for propName, propSchema := range props {
						cleanedProps[propName] = cleanToolSchema(propSchema)
					}
					cleaned[key] = cleanedProps
					continue
				}
			}
			// Recursively clean nested values.
			cleaned[key] = cleanToolSchema(value)
		}
		// Normalize the type field to upper case.
		if typeVal, ok := cleaned["type"].(string); ok {
			cleaned["type"] = strings.ToUpper(typeVal)
		}
		return cleaned
	case []any:
		cleaned := make([]any, len(v))
		for i, item := range v {
			cleaned[i] = cleanToolSchema(item)
		}
		return cleaned
	default:
		return v
	}
}
|
||
|
||
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||
out := make(map[string]any)
|
||
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
|
||
out["maxOutputTokens"] = mt
|
||
}
|
||
if temp, ok := req["temperature"].(float64); ok {
|
||
out["temperature"] = temp
|
||
}
|
||
if topP, ok := req["top_p"].(float64); ok {
|
||
out["topP"] = topP
|
||
}
|
||
if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 {
|
||
out["stopSequences"] = stopSeq
|
||
}
|
||
if len(out) == 0 {
|
||
return nil
|
||
}
|
||
return out
|
||
}
|