sub2api/backend/internal/service/gemini_messages_compat_service.go
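IanShaw 45bd9ac705 Ops monitoring system security hardening and feature optimization (#21)
* fix(ops): Fix critical security and stability issues in the ops monitoring system

## Fixes

### P0 critical issues
1. **DNS rebinding protection** (ops_alert_service.go)
   - Implement IP pinning to prevent DNS rebinding attacks after validation
   - Use a custom Transport.DialContext that only allows dialing validated public IPs
   - Extend the IP blocklist to include the cloud metadata address (169.254.169.254)
   - Add full unit test coverage

2. **OpsAlertService lifecycle management** (wire.go)
   - Call opsAlertService.Start() in ProvideOpsMetricsCollector
   - Ensure stopCtx is initialized correctly to avoid nil pointer issues
   - Start defensively to guarantee the service startup order

3. **Database query ordering** (ops_repo.go)
   - Add an explicit ORDER BY updated_at DESC, id DESC in ListRecentSystemMetrics
   - Add an ordering guarantee in GetLatestSystemMetric
   - Avoid false alerts caused by nondeterministic row order from the database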

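### P1 important issues
4. **Concurrency safety** (ops_metrics_collector.go)
   - Protect the lastGCPauseTotal field with a sync.Mutex
   - Prevent data races

5. **Goroutine leak** (ops_error_logger.go)
   - Use a worker pool to bound the number of concurrent goroutines
   - Use a buffered queue with capacity 256 and 10 fixed workers
   - Submit without blocking; drop tasks when the queue is full

6. **Lifecycle control** (ops_alert_service.go)
   - Add Start/Stop methods for graceful shutdown
   - Use a context to control goroutine lifetimes
   - Use a WaitGroup to wait for background tasks to finish

7. **Webhook URL validation** (ops_alert_service.go)
   - Prevent SSRF: validate the scheme and reject internal IPs
   - Validate DNS resolution and reject domains that resolve to private IPs
   - Add 8 unit tests covering various attack scenarios

8. **Resource leaks** (ops_repo.go)
   - Fix several defer rows.Close() issues
   - Simplify redundant defer func() wrappers

9. **HTTP timeout control** (ops_alert_service.go)
   - Create an http.Client with a 10-second timeout
   - Add the buildWebhookHTTPClient helper
   - Prevent HTTP requests from hanging indefinitely

10. **Database query optimization** (ops_repo.go)
    - Merge the 4 separate queries in GetWindowStats into 1 CTE query
    - Reduce network round trips and table scans
    - Significantly improve performance

11. **Retry mechanism** (ops_alert_service.go)
    - Retry email delivery: up to 3 attempts with exponential backoff (1s/2s/4s)
    - Add a webhook fallback channel
    - Implement full error handling and logging

12. **Magic numbers** (ops_repo.go, ops_metrics_collector.go)
    - Extract hard-coded numbers into meaningful constants
    - Improve readability and maintainability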

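## Test verification
- go test ./internal/service -tags opsalert_unit passes
- All webhook validation tests pass
- Retry mechanism tests pass

## Impact
- Significantly improved security of the ops monitoring system
- Improved system stability and performance
- No breaking changes; backward compatible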

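* feat(ops): Ops monitoring system V2 - full implementation

## Core features
- Ops monitoring dashboard V2 (real-time monitoring, historical trends, alert management)
- Real-time QPS/TPS monitoring over WebSocket (30s heartbeat, automatic reconnect)
- System metrics collection (CPU, memory, latency, error rate, etc.)
- Multi-dimensional statistics (by provider, model, user, etc.)
- Alert rule management (threshold configuration, notification channels)
- Error log tracing (detailed error information, stack traces)

## Database schema (Migration 025)
### Extended existing tables
- ops_system_metrics: add RED metrics, error classification, latency metrics, resource metrics, business metrics
- ops_alert_rules: add JSONB columns (dimension_filters, notify_channels, notify_config)

### New tables
- ops_dimension_stats: multi-dimensional statistics
- ops_data_retention_config: data retention policy configuration

### New views and functions
- ops_latest_metrics: metrics for the latest 1-minute window (column names and window filter fixed)
- ops_active_alerts: currently active alerts (column names and status values fixed)
- calculate_health_score: health score calculation function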

## Consistency fixes (score: 98/100)
### P0 (blocking the migration)
- Fix ops_latest_metrics view column names (latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
- Fix ops_active_alerts view column names (metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
- Unify the alert history table name (drop ops_alert_history, use ops_alert_events)
- Unify API parameter limits (limit for ListMetricsHistory and ListErrorLogs changed to 5000)

### P1 (functional completeness)
- Fix the ops_latest_metrics view not filtering on window_minutes (add WHERE m.window_minutes = 1)
- Fix the data backfill UPDATE logic (QPS is now computed as request_count/(window_minutes*60.0))
- Add backend support for the ops_alert_rules JSONB columns (Go structs + serialization)
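
The corrected QPS formula from the backfill fix, request_count/(window_minutes*60.0), is simple enough to check by hand. The sketch below restates it in Go; the function name `qps` and the guard for a non-positive window are assumptions for illustration, since the actual fix lives in a SQL UPDATE.

```go
package main

// qps returns the average number of requests per second over a window of
// windowMinutes minutes. It returns 0 for a non-positive window (an assumed
// guard, not something the migration is known to do).
func qps(requestCount int64, windowMinutes int) float64 {
	if windowMinutes <= 0 {
		return 0
	}
	return float64(requestCount) / (float64(windowMinutes) * 60.0)
}
```

For example, 600 requests in a 1-minute window is 10 requests per second, while the same 600 requests in a 5-minute window is 2 requests per second.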

### P2 (optimization)
- Frontend WebSocket automatic reconnect (exponential backoff 1s→2s→4s→8s→16s, up to 5 attempts)
- Backend WebSocket heartbeat (30s ping, 60s pong timeout)

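## Technical implementation
### Backend (Go)
- Handler layer: ops_handler.go (REST API), ops_ws_handler.go (WebSocket)
- Service layer: ops_service.go (core logic), ops_cache.go (cache), ops_alerts.go (alerts)
- Repository layer: ops_repo.go (data access), ops.go (model definitions)
- Routing: admin.go (new ops routes)
- Dependency injection: wire_gen.go (generated)

### Frontend (Vue3 + TypeScript)
- Component: OpsDashboardV2.vue (main dashboard component)
- API: ops.ts (REST API + WebSocket wrapper)
- Routing: index.ts (new /admin/ops route)
- i18n: en.ts, zh.ts (English and Chinese)

## Test verification
- All Go tests pass
- The migration runs successfully
- The WebSocket connection is stable
- Frontend and backend data structures are aligned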

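* refactor: code cleanup and test optimization

## Test file optimization
- Simplify integration test fixtures and assertions
- Optimize test helper functions
- Unify test data formats

## Code cleanup
- Remove unused code and comments
- Simplify the concurrency_cache implementation
- Improve middleware error handling

## Minor fixes
- Fix minor issues in gateway_handler and openai_gateway_handler
- Unify code style and formatting

Change statistics: 27 files, 292 lines added, 322 lines deleted (net reduction of 30 lines)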

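* fix(ops): Ops monitoring system security hardening and feature optimization

## Security enhancements
- feat(security): redact WebSocket logs to prevent token/api_key leakage
- feat(security): validate X-Forwarded-Host against an allowlist to prevent CSRF bypass
- feat(security): make the Origin policy configurable, with strict/permissive modes
- feat(auth): WebSocket authentication supports passing the token as a query parameter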
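
Because the token may now travel in the query string, log redaction matters. The sketch below shows one way such redaction could work. It is an assumption-laden illustration, not the code in logger.go: the name `redactSecrets` and the list of parameter names are hypothetical.

```go
package main

import "regexp"

// secretParam matches token-like query parameters together with their
// values. The parameter list here is an assumption for illustration.
var secretParam = regexp.MustCompile(`(?i)([?&](?:token|api_key|access_token)=)[^&\s"]+`)

// redactSecrets replaces the value of each sensitive query parameter with
// *** while keeping the parameter name, so log lines stay readable.
func redactSecrets(s string) string {
	return secretParam.ReplaceAllString(s, "${1}***")
}
```

Keeping the parameter name and masking only the value makes it possible to see that a token was supplied without exposing it.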

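## Configuration
- feat(config): support environment variables for proxy trust and the Origin policy
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): lower the error log query limit from 5000 to 500 to reduce memory usage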

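## Architecture improvements
- refactor(ops): decouple the alert service, which now runs its own evaluation timer
- refactor(ops): unify OpsDashboard into a single version, removing the separate V2

## Tests and docs
- test(ops): add WebSocket security validation unit tests (8 test cases)
- test(ops): add alert service integration tests
- docs(api): update the API docs to note the limit change
- docs: add a CHANGELOG recording breaking changes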

## Changed files
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (new)
- backend/internal/service/ops_alert_service_integration_test.go (new)

Docs:
- CHANGELOG.md (new)
- docs/API-运维监控中心2.0.md (updated)

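* fix(migrations): fix the calculate_health_score function type mismatch

Add explicit type casts in the ops_latest_metrics view so that argument types match the function signature

* fix(lint): fix all issues found by golangci-lint

- Move the Redis dependency from the service layer to the repository layer
- Add error checks (WebSocket connection and read timeouts)
- Run gofmt to format the code
- Add nil pointer checks
- Remove the unused alertService field

Issues fixed:
- depguard: 3 (the service layer must not import redis directly)
- errcheck: 3 (unchecked error return values)
- gofmt: 2 (formatting issues)
- staticcheck: 4 (nil pointer dereference)
- unused: 1 (unused field)

Code statistics:
- Files changed: 11
- Lines deleted: 490
- Lines added: 105
- Net reduction: 385 lines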
2026-01-02 20:01:12 +08:00

2398 lines
66 KiB
Go

package service
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
mathrand "math/rand"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/gin-gonic/gin"
)
const geminiStickySessionTTL = time.Hour
const (
geminiMaxRetries = 5
geminiRetryBaseDelay = 1 * time.Second
geminiRetryMaxDelay = 16 * time.Second
)
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
groupRepo GroupRepository
cache GatewayCache
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
}
func NewGeminiMessagesCompatService(
accountRepo AccountRepository,
groupRepo GroupRepository,
cache GatewayCache,
tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
groupRepo: groupRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
}
}
// GetTokenProvider returns the token provider for OAuth accounts
func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// Check the forced platform in the context first (/antigravity route).
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// Choose which kind of account to query based on the group's platform.
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// Without a group, use only the native gemini platform.
platform = PlatformGemini
}
// gemini groups support mixed scheduling (including antigravity accounts with mixed_scheduling enabled).
// Note: forced-platform mode does not use mixed scheduling.
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
var queryPlatforms []string
if useMixedScheduling {
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
} else {
queryPlatforms = []string{platform}
}
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// Check that the account is valid: native-platform accounts match directly; antigravity accounts require mixed scheduling to be enabled.
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
}
// Query schedulable accounts (forced-platform mode: look in the group first, then fall back to all accounts).
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// In forced-platform mode, fall back to querying all accounts when the group has none.
if len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
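// In mixed scheduling mode, native-platform accounts pass directly; antigravity accounts require mixed_scheduling to be enabled.
// Outside mixed scheduling mode (antigravity groups), no filtering is needed.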
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil
}
// isModelSupportedByAccount checks model support based on the account's platform.
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
}
return account.IsModelSupported(requestedModel)
}
// GetAntigravityGatewayService returns the AntigravityGatewayService.
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
// HasAntigravityAccounts reports whether any antigravity accounts are available.
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
}
if err != nil {
return false, err
}
return len(accounts) > 0, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
// Preference order:
// 1) API key accounts (AI Studio)
// 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
if len(accounts) == 0 {
return nil, errors.New("no available Gemini accounts")
}
rank := func(a *Account) int {
if a == nil {
return 999
}
switch a.Type {
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
return 9
case AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1
}
if strings.TrimSpace(a.GetCredential("oauth_type")) == "ai_studio" {
return 2
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
default:
return 10
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if selected == nil {
selected = acc
continue
}
r1, r2 := rank(acc), rank(selected)
if r1 < r2 {
selected = acc
continue
}
if r1 > r2 {
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
return nil, errors.New("no available Gemini accounts")
}
return selected, nil
}
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
if strings.TrimSpace(req.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
}
switch account.Type {
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
action := "generateContent"
if useUpstreamStream {
action = "streamGenerateContent"
}
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(geminiReq, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
var resp *http.Response
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error())
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// Don't treat insufficient-scope as transient.
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode == 429 {
// Mark as rate-limited early so concurrent requests avoid this account.
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
// Final attempt: surface the upstream error body (mapped below) instead of a generic retry error.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
}
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
if req.Stream {
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
}
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
c.JSON(http.StatusOK, claudeResp)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
usage = usageObj
}
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
}
}
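// Guard against a nil usage before dereferencing it below (mirrors ForwardNative).
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{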
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
useUpstreamStream := stream
upstreamAction := action
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
}
forceAIStudio := action == "countTokens"
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("x-goog-api-key", apiKey)
return upstreamReq, "x-request-id", nil
}
requestIDHeader = "x-request-id"
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, "", err
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// Two modes for OAuth:
// 1. With project_id -> Code Assist API (wrapped request)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
wrapped := map[string]any{
"model": mappedModel,
"project": projectID,
}
var inner any
if err := json.Unmarshal(body, &inner); err != nil {
return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
}
wrapped["request"] = inner
wrappedBytes, _ := json.Marshal(wrapped)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
if err != nil {
return nil, "", err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
return upstreamReq, "x-request-id", nil
}
}
requestIDHeader = "x-request-id"
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
var resp *http.Response
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
// Local build error: don't retry.
if strings.Contains(err.Error(), "missing project_id") {
return nil, s.writeGoogleError(c, http.StatusBadRequest, err.Error())
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, err.Error())
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// Don't treat insufficient-scope as transient.
if resp.StatusCode == 403 && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode == 429 {
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
if requestID != "" {
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream {
streamRes, err := s.handleNativeStreamingResponse(c, resp, startTime, isOAuth)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, isOAuth)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to read upstream stream")
}
b, _ := json.Marshal(collected)
c.Data(http.StatusOK, "application/json", b)
usage = usageObj
} else {
usageResp, err := s.handleNativeNonStreamingResponse(c, resp, isOAuth)
if err != nil {
return nil, err
}
usage = usageResp
}
}
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
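// shouldRetryGeminiUpstreamError reports whether a request that failed with statusCode
// should be retried against the same account.
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {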
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
case 403:
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
if account == nil || account.Type != AccountTypeOAuth {
return false
}
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
if oauthType == "" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Legacy/implicit Code Assist OAuth accounts.
oauthType = "code_assist"
}
return oauthType == "code_assist"
default:
return false
}
}
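// shouldFailoverGeminiUpstreamError reports whether an upstream status code should
// trigger failover (401, 403, 429, 529, or any 5xx).
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {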
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
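// sleepGeminiBackoff sleeps for an exponentially growing delay: geminiRetryBaseDelay
// doubled per attempt (attempt 1 yields the base delay), capped at geminiRetryMaxDelay,
// with +/- 20% jitter applied.
func sleepGeminiBackoff(attempt int) {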
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
time.Sleep(sleepFor)
}
var (
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)
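// sanitizeUpstreamErrorMessage masks the values of sensitive query parameters
// (key, client_secret, access_token, refresh_token) found in msg.
func sanitizeUpstreamErrorMessage(msg string) string {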
if msg == "" {
return msg
}
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
}
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
var statusCode int
var errType, errMsg string
if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
errType = mapped.Type
if mapped.Message != "" {
errMsg = mapped.Message
}
if mapped.StatusCode > 0 {
statusCode = mapped.StatusCode
}
}
switch upstreamStatus {
case 400:
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
if errType == "" {
errType = "invalid_request_error"
}
if errMsg == "" {
errMsg = "Invalid request"
}
case 401:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "authentication_error"
}
if errMsg == "" {
errMsg = "Upstream authentication failed, please contact administrator"
}
case 403:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "permission_error"
}
if errMsg == "" {
errMsg = "Upstream access forbidden, please contact administrator"
}
case 404:
if statusCode == 0 {
statusCode = http.StatusNotFound
}
if errType == "" {
errType = "not_found_error"
}
if errMsg == "" {
errMsg = "Resource not found"
}
case 429:
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}
if errType == "" {
errType = "rate_limit_error"
}
if errMsg == "" {
errMsg = "Upstream rate limit exceeded, please retry later"
}
case 529:
if statusCode == 0 {
statusCode = http.StatusServiceUnavailable
}
if errType == "" {
errType = "overloaded_error"
}
if errMsg == "" {
errMsg = "Upstream service overloaded, please retry later"
}
case 500, 502, 503, 504:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
switch upstreamStatus {
case 504:
errType = "timeout_error"
case 503:
errType = "overloaded_error"
default:
errType = "api_error"
}
}
if errMsg == "" {
errMsg = "Upstream service temporarily unavailable"
}
default:
if statusCode == 0 {
statusCode = http.StatusBadGateway
}
if errType == "" {
errType = "upstream_error"
}
if errMsg == "" {
errMsg = "Upstream request failed"
}
}
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
type claudeErrorMapping struct {
Type string
Message string
StatusCode int
}
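// mapGeminiErrorBodyToClaudeError parses a Google-style error body and maps its status
// to a Claude error type and, for some statuses, an HTTP status code. It returns nil when
// the body is empty, cannot be parsed, or carries no error details. The upstream message
// is intentionally not propagated.
func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping {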
if len(body) == 0 {
return nil
}
var parsed struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" {
return nil
}
mapped := &claudeErrorMapping{
Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status),
Message: "",
}
if mapped.Type == "" {
mapped.Type = "upstream_error"
}
switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) {
case "INVALID_ARGUMENT":
mapped.StatusCode = http.StatusBadRequest
case "NOT_FOUND":
mapped.StatusCode = http.StatusNotFound
case "RESOURCE_EXHAUSTED":
mapped.StatusCode = http.StatusTooManyRequests
default:
// Keep StatusCode unset and let HTTP status mapping decide.
}
// Keep messages generic by default; upstream error message can be long or include sensitive fragments.
return mapped
}
func mapGeminiStatusToClaudeErrorType(status string) string {
switch strings.ToUpper(strings.TrimSpace(status)) {
case "INVALID_ARGUMENT":
return "invalid_request_error"
case "PERMISSION_DENIED":
return "permission_error"
case "NOT_FOUND":
return "not_found_error"
case "RESOURCE_EXHAUSTED":
return "rate_limit_error"
case "UNAUTHENTICATED":
return "authentication_error"
case "UNAVAILABLE":
return "overloaded_error"
case "INTERNAL":
return "api_error"
case "DEADLINE_EXCEEDED":
return "timeout_error"
default:
return ""
}
}
type geminiStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
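// handleNonStreamingResponse reads a Gemini response (up to 8 MiB), converts it to a
// Claude message, and writes it to the client.
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {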
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
geminiResp, err := unwrapGeminiResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel)
c.JSON(http.StatusOK, claudeResp)
return usage, nil
}
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
messageID := "msg_" + randomHex(12)
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": messageID,
"type": "message",
"role": "assistant",
"model": originalModel,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
writeSSE(c.Writer, "message_start", messageStart)
flusher.Flush()
var firstTokenMs *int
var usage ClaudeUsage
finishReason := ""
sawToolUse := false
nextBlockIndex := 0
openBlockIndex := -1
openBlockType := ""
seenText := ""
openToolIndex := -1
openToolID := ""
openToolName := ""
seenToolJSON := ""
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
if !strings.HasPrefix(line, "data:") {
if errors.Is(err, io.EOF) {
break
}
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if payload == "" || payload == "[DONE]" {
if errors.Is(err, io.EOF) {
break
}
continue
}
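geminiResp, parseErr := unwrapGeminiResponse([]byte(payload))
if parseErr != nil {
// Skip malformed chunks. A separate variable keeps err as the read error,
// so the EOF check at the end of the loop body still sees it.
if errors.Is(err, io.EOF) {
break
}
continue
}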
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
finishReason = fr
}
parts := extractGeminiParts(geminiResp)
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
delta, newSeen := computeGeminiTextDelta(seenText, text)
seenText = newSeen
if delta == "" {
continue
}
if openBlockType != "text" {
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
openBlockType = "text"
openBlockIndex = nextBlockIndex
nextBlockIndex++
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": delta,
},
})
flusher.Flush()
continue
}
if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
name, _ := fc["name"].(string)
args := fc["args"]
if strings.TrimSpace(name) == "" {
name = "tool"
}
// Close any open text block before tool_use.
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
openBlockIndex = -1
openBlockType = ""
}
// If we receive streamed tool args in pieces, keep a single tool block open and emit deltas.
if openToolIndex >= 0 && openToolName != name {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
openToolIndex = -1
openToolName = ""
seenToolJSON = ""
}
if openToolIndex < 0 {
openToolID = "toolu_" + randomHex(8)
openToolIndex = nextBlockIndex
openToolName = name
nextBlockIndex++
sawToolUse = true
writeSSE(c.Writer, "content_block_start", map[string]any{
"type": "content_block_start",
"index": openToolIndex,
"content_block": map[string]any{
"type": "tool_use",
"id": openToolID,
"name": name,
"input": map[string]any{},
},
})
}
argsJSONText := "{}"
switch v := args.(type) {
case nil:
// keep default "{}"
case string:
if strings.TrimSpace(v) != "" {
argsJSONText = v
}
default:
if b, err := json.Marshal(args); err == nil && len(b) > 0 {
argsJSONText = string(b)
}
}
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
seenToolJSON = newSeen
if delta != "" {
writeSSE(c.Writer, "content_block_delta", map[string]any{
"type": "content_block_delta",
"index": openToolIndex,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": delta,
},
})
}
flusher.Flush()
}
}
if u := extractGeminiUsage(geminiResp); u != nil {
usage = *u
}
// Process the final unterminated line at EOF as well.
if errors.Is(err, io.EOF) {
break
}
}
if openBlockIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openBlockIndex,
})
}
if openToolIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
}
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
if sawToolUse {
stopReason = "tool_use"
}
usageObj := map[string]any{
"output_tokens": usage.OutputTokens,
}
if usage.InputTokens > 0 {
usageObj["input_tokens"] = usage.InputTokens
}
writeSSE(c.Writer, "message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usageObj,
})
writeSSE(c.Writer, "message_stop", map[string]any{
"type": "message_stop",
})
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
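// writeSSE writes one Server-Sent Event: an optional event line, then the JSON-encoded
// data followed by a blank line. Marshal and write errors are ignored.
func writeSSE(w io.Writer, event string, data any) {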
if event != "" {
_, _ = fmt.Fprintf(w, "event: %s\n", event)
}
b, _ := json.Marshal(data)
_, _ = fmt.Fprintf(w, "data: %s\n\n", string(b))
}
func randomHex(nBytes int) string {
b := make([]byte, nBytes)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": message},
})
return errors.New(message)
}
func (s *GeminiMessagesCompatService) writeGoogleError(c *gin.Context, status int, message string) error {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
return errors.New(message)
}
func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
if !isOAuth {
return raw
}
inner, err := unwrapGeminiResponse(raw)
if err != nil {
return raw
}
b, err := json.Marshal(inner)
if err != nil {
return raw
}
return b
}
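// collectGeminiSSE reads an upstream SSE stream until [DONE] or EOF. It returns the last
// chunk that contained parts (falling back to the last parsed chunk) together with the
// most recent usage metadata.
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {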
reader := bufio.NewReader(body)
var last map[string]any
var lastWithParts map[string]any
usage := &ClaudeUsage{}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return pickGeminiCollectResult(last, lastWithParts), usage, nil
}
default:
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
last = parsed
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
}
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, nil, err
}
}
return pickGeminiCollectResult(last, lastWithParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
if lastWithParts != nil {
return lastWithParts
}
if last != nil {
return last
}
return map[string]any{}
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
if strings.Contains(strings.ToLower(headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
lower := strings.ToLower(string(body))
return strings.Contains(lower, "insufficient authentication scopes") || strings.Contains(lower, "access_token_scope_insufficient")
}
func estimateGeminiCountTokens(reqBody []byte) int {
var obj map[string]any
if err := json.Unmarshal(reqBody, &obj); err != nil {
return 0
}
var texts []string
// systemInstruction.parts[].text
if si, ok := obj["systemInstruction"].(map[string]any); ok {
if parts, ok := si["parts"].([]any); ok {
for _, p := range parts {
if pm, ok := p.(map[string]any); ok {
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
}
// contents[].parts[].text
if contents, ok := obj["contents"].([]any); ok {
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
}
parts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range parts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
total := 0
for _, t := range texts {
total += estimateTokensForText(t)
}
if total < 0 {
return 0
}
return total
}
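// estimateTokensForText returns a rough token estimate: about 4 characters per token
// for mostly-ASCII text, and 1 rune per token otherwise.
func estimateTokensForText(s string) int {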
s = strings.TrimSpace(s)
if s == "" {
return 0
}
runes := []rune(s)
if len(runes) == 0 {
return 0
}
ascii := 0
for _, r := range runes {
if r <= 0x7f {
ascii++
}
}
asciiRatio := float64(ascii) / float64(len(runes))
if asciiRatio >= 0.8 {
// Roughly 4 chars per token for English-like text.
return (len(runes) + 3) / 4
}
// For CJK-heavy text, approximate 1 rune per token.
return len(runes)
}
type UpstreamHTTPResult struct {
StatusCode int
Headers http.Header
Body []byte
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var parsed map[string]any
if isOAuth {
parsed, err = unwrapGeminiResponse(respBody)
if err == nil && parsed != nil {
respBody, _ = json.Marshal(parsed)
}
} else {
_ = json.Unmarshal(respBody, &parsed)
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
return u, nil
}
}
return &ClaudeUsage{}, nil
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
// Keepalive / done markers
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
rawToWrite := payload
var parsed map[string]any
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
if b, err := json.Marshal(inner); err == nil {
rawToWrite = string(b)
}
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
}
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if isOAuth {
// SSE format requires double newline (\n\n) to separate events
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", rawToWrite)
} else {
// Pass-through for AI Studio responses.
_, _ = io.WriteString(c.Writer, line)
}
flusher.Flush()
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &geminiNativeStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// ForwardAIStudioGET forwards a GET request to AI Studio (generativelanguage.googleapis.com) for
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
if account == nil {
return nil, errors.New("account is nil")
}
path = strings.TrimSpace(path)
if path == "" || !strings.HasPrefix(path, "/") {
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
return nil, err
}
switch account.Type {
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
}
req.Header.Set("x-goog-api-key", apiKey)
case AccountTypeOAuth:
if s.tokenProvider == nil {
return nil, errors.New("gemini token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Body: body,
}, nil
}
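// unwrapGeminiResponse decodes raw JSON. When the payload is wrapped in a top-level
// "response" object, the inner object is returned; otherwise the decoded object itself.
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {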
var outer map[string]any
if err := json.Unmarshal(raw, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"].(map[string]any); ok && resp != nil {
return resp, nil
}
return outer, nil
}
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
usage := extractGeminiUsage(geminiResp)
if usage == nil {
usage = &ClaudeUsage{}
}
contentBlocks := make([]any, 0)
sawToolUse := false
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if content, ok := cand["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
pm, ok := part.(map[string]any)
if !ok {
continue
}
if text, ok := pm["text"].(string); ok && text != "" {
contentBlocks = append(contentBlocks, map[string]any{
"type": "text",
"text": text,
})
}
if fc, ok := pm["functionCall"].(map[string]any); ok {
name, _ := fc["name"].(string)
if strings.TrimSpace(name) == "" {
name = "tool"
}
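args := fc["args"]
if args == nil {
// Claude expects tool_use.input to be an object; default to an empty one.
args = map[string]any{}
}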
sawToolUse = true
contentBlocks = append(contentBlocks, map[string]any{
"type": "tool_use",
"id": "toolu_" + randomHex(8),
"name": name,
"input": args,
})
}
}
}
}
}
}
stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp))
if sawToolUse {
stopReason = "tool_use"
}
resp := map[string]any{
"id": "msg_" + randomHex(12),
"type": "message",
"role": "assistant",
"model": originalModel,
"content": contentBlocks,
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": usage.InputTokens,
"output_tokens": usage.OutputTokens,
},
}
return resp, usage
}
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
usageMeta, ok := geminiResp["usageMetadata"].(map[string]any)
if !ok || usageMeta == nil {
return nil
}
prompt, _ := asInt(usageMeta["promptTokenCount"])
cand, _ := asInt(usageMeta["candidatesTokenCount"])
return &ClaudeUsage{
InputTokens: prompt,
OutputTokens: cand,
}
}
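// asInt converts the numeric types that JSON decoding can produce to int.
// The second result is false when v is not a supported numeric type.
func asInt(v any) (int, bool) {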
switch t := v.(type) {
case float64:
return int(t), true
case int:
return t, true
case int64:
return int(t), true
case json.Number:
i, err := t.Int64()
if err != nil {
return 0, false
}
return int(i), true
default:
return 0, false
}
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return
}
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// Use a different default reset time depending on the account type.
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: reset at PST midnight.
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// Fallback: 5 minutes.
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
// Use the reset time parsed from the response.
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime parses a Gemini-format 429 response and returns the
// reset time as a Unix timestamp, or nil if no reset time can be determined.
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
// Daily-quota messages map to the next daily reset; otherwise try to parse
// metadata.quotaResetDelay like "12.345s".
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err == nil {
if errObj, ok := parsed["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok {
if looksLikeGeminiDailyQuota(msg) {
if ts := nextGeminiDailyResetUnix(); ts != nil {
return ts
}
}
}
if details, ok := errObj["details"].([]any); ok {
for _, d := range details {
dm, ok := d.(map[string]any)
if !ok {
continue
}
if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil {
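// Round up, as in the "Please retry in Xs" path below, so the account is not released early.
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))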
return &ts
}
}
}
}
}
}
}
// Match "Please retry in Xs"
matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts
}
}
return nil
}
func looksLikeGeminiDailyQuota(message string) bool {
m := strings.ToLower(message)
// Any message mentioning "per day" (e.g. "requests per day") is treated as a daily quota.
return strings.Contains(m, "per day")
}
func nextGeminiDailyResetUnix() *int64 {
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}
func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok {
return fr
}
}
}
return ""
}
func extractGeminiParts(geminiResp map[string]any) []map[string]any {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if content, ok := cand["content"].(map[string]any); ok {
if partsAny, ok := content["parts"].([]any); ok && len(partsAny) > 0 {
out := make([]map[string]any, 0, len(partsAny))
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok {
continue
}
out = append(out, pm)
}
return out
}
}
}
}
return nil
}
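// computeGeminiTextDelta returns the part of incoming that has not been emitted yet,
// along with the updated accumulated text. It handles both cumulative chunks (full text
// so far) and incremental chunks.
func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) {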
incoming = strings.TrimSuffix(incoming, "\u0000")
if incoming == "" {
return "", seen
}
// Cumulative mode: incoming contains full text so far.
if strings.HasPrefix(incoming, seen) {
return strings.TrimPrefix(incoming, seen), incoming
}
// Duplicate/rewind: ignore.
if strings.HasPrefix(seen, incoming) {
return "", seen
}
// Delta mode: treat incoming as incremental chunk.
return incoming, seen + incoming
}
func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string {
switch strings.ToUpper(strings.TrimSpace(finishReason)) {
case "MAX_TOKENS":
return "max_tokens"
case "STOP":
return "end_turn"
default:
return "end_turn"
}
}
func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
toolUseIDToName := make(map[string]string)
systemText := extractClaudeSystemText(req["system"])
contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName)
if err != nil {
return nil, err
}
out := make(map[string]any)
if systemText != "" {
out["systemInstruction"] = map[string]any{
"parts": []any{map[string]any{"text": systemText}},
}
}
out["contents"] = contents
if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil {
out["tools"] = tools
}
generationConfig := convertClaudeGenerationConfig(req)
if generationConfig != nil {
out["generationConfig"] = generationConfig
}
stripGeminiFunctionIDs(out)
return json.Marshal(out)
}
func stripGeminiFunctionIDs(req map[string]any) {
// Defensive cleanup: some upstreams reject unexpected `id` fields in functionCall/functionResponse.
contents, ok := req["contents"].([]any)
if !ok {
return
}
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
}
contentParts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range contentParts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); ok && fc != nil {
delete(fc, "id")
}
if fr, ok := pm["functionResponse"].(map[string]any); ok && fr != nil {
delete(fr, "id")
}
}
}
}
func extractClaudeSystemText(system any) string {
switch v := system.(type) {
case string:
return strings.TrimSpace(v)
case []any:
var parts []string
for _, p := range v {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, _ := pm["type"].(string); t != "text" {
continue
}
if text, ok := pm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
return strings.TrimSpace(strings.Join(parts, "\n"))
default:
return ""
}
}
func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) {
arr, ok := messages.([]any)
if !ok {
return nil, errors.New("messages must be an array")
}
out := make([]any, 0, len(arr))
for _, m := range arr {
mm, ok := m.(map[string]any)
if !ok {
continue
}
role, _ := mm["role"].(string)
role = strings.ToLower(strings.TrimSpace(role))
gRole := "user"
if role == "assistant" {
gRole = "model"
}
parts := make([]any, 0)
switch content := mm["content"].(type) {
case string:
if strings.TrimSpace(content) != "" {
parts = append(parts, map[string]any{"text": content})
}
case []any:
for _, block := range content {
bm, ok := block.(map[string]any)
if !ok {
continue
}
bt, _ := bm["type"].(string)
switch bt {
case "text":
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
case "tool_use":
id, _ := bm["id"].(string)
name, _ := bm["name"].(string)
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name
}
parts = append(parts, map[string]any{
"functionCall": map[string]any{
"name": name,
"args": bm["input"],
},
})
case "tool_result":
toolUseID, _ := bm["tool_use_id"].(string)
name := toolUseIDToName[toolUseID]
if name == "" {
name = "tool"
}
parts = append(parts, map[string]any{
"functionResponse": map[string]any{
"name": name,
"response": map[string]any{
"content": extractClaudeContentText(bm["content"]),
},
},
})
case "image":
if src, ok := bm["source"].(map[string]any); ok {
if srcType, _ := src["type"].(string); srcType == "base64" {
mediaType, _ := src["media_type"].(string)
data, _ := src["data"].(string)
if mediaType != "" && data != "" {
parts = append(parts, map[string]any{
"inlineData": map[string]any{
"mimeType": mediaType,
"data": data,
},
})
}
}
}
default:
// best-effort: preserve unknown blocks as text
if b, err := json.Marshal(bm); err == nil {
parts = append(parts, map[string]any{"text": string(b)})
}
}
}
default:
// ignore
}
out = append(out, map[string]any{
"role": gRole,
"parts": parts,
})
}
return out, nil
}
func extractClaudeContentText(v any) string {
switch t := v.(type) {
case string:
return t
case []any:
var sb strings.Builder
for _, part := range t {
pm, ok := part.(map[string]any)
if !ok {
continue
}
if pm["type"] == "text" {
if text, ok := pm["text"].(string); ok {
_, _ = sb.WriteString(text)
}
}
}
return sb.String()
default:
b, _ := json.Marshal(t)
return string(b)
}
}
func convertClaudeToolsToGeminiTools(tools any) []any {
arr, ok := tools.([]any)
if !ok || len(arr) == 0 {
return nil
}
funcDecls := make([]any, 0, len(arr))
for _, t := range arr {
tm, ok := t.(map[string]any)
if !ok {
continue
}
var name, desc string
var params any
// Check whether this is a custom-type tool (MCP).
toolType, _ := tm["type"].(string)
if toolType == "custom" {
// Custom format: read description and input_schema from the custom field.
custom, ok := tm["custom"].(map[string]any)
if !ok {
continue
}
name, _ = tm["name"].(string)
desc, _ = custom["description"].(string)
params = custom["input_schema"]
} else {
// Standard format: read from the top-level fields.
name, _ = tm["name"].(string)
desc, _ = tm["description"].(string)
params = tm["input_schema"]
}
if name == "" {
continue
}
// Provide a default (empty object schema) when params is nil.
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// Clean the JSON Schema.
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{
"name": name,
"description": desc,
"parameters": cleanedParams,
})
}
if len(funcDecls) == 0 {
return nil
}
return []any{
map[string]any{
"functionDeclarations": funcDecls,
},
}
}
// cleanToolSchema cleans a tool's JSON Schema: it removes fields Gemini does not
// support and uppercases type values.
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// Skip unsupported fields.
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// Recursively clean nested values.
cleaned[key] = cleanToolSchema(value)
}
// Normalize the type field to uppercase.
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
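// convertClaudeGenerationConfig maps Claude sampling parameters (max_tokens,
// temperature, top_p, stop_sequences) to their Gemini generation config keys.
// It returns nil when none of them are set.
func convertClaudeGenerationConfig(req map[string]any) map[string]any {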
out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
out["maxOutputTokens"] = mt
}
if temp, ok := req["temperature"].(float64); ok {
out["temperature"] = temp
}
if topP, ok := req["top_p"].(float64); ok {
out["topP"] = topP
}
if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 {
out["stopSequences"] = stopSeq
}
if len(out) == 0 {
return nil
}
return out
}