Files
sub2api/backend/internal/service/antigravity_gateway_service.go

3743 lines
124 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
mathrand "math/rand"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
// 限流相关常量
// antigravityRateLimitThreshold 限流等待/切换阈值
// - 智能重试retryDelay < 此阈值时等待后重试,>= 此阈值时直接限流模型
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
antigravityRateLimitThreshold = 7 * time.Second
antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// Google RPC 状态和类型常量
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
googleRPCStatusUnavailable = "UNAVAILABLE"
googleRPCTypeRetryInfo = "type.googleapis.com/google.rpc.RetryInfo"
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
)
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
// 匹配时使用 strings.Contains无需完全匹配
var antigravityPassthroughErrorMessages = []string{
"prompt is too long",
}
const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// AntigravityAccountSwitchError 账号切换信号
// 当账号限流时间超过阈值时,通知上层切换账号
type AntigravityAccountSwitchError struct {
OriginalAccountID int64
RateLimitedModel string
IsStickySession bool // 是否为粘性会话切换(决定是否缓存计费)
}
func (e *AntigravityAccountSwitchError) Error() string {
return fmt.Sprintf("account %d model %s rate limited, need switch",
e.OriginalAccountID, e.RateLimitedModel)
}
// IsAntigravityAccountSwitchError 检查错误是否为账号切换信号
func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, bool) {
var switchErr *AntigravityAccountSwitchError
if errors.As(err, &switchErr) {
return switchErr, true
}
return nil, false
}
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
RequestID string
Body []byte
}
func (e *PromptTooLongError) Error() string {
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
}
// antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct {
ctx context.Context
prefix string
account *Account
proxyURL string
accessToken string
action string
body []byte
quotaScope AntigravityQuotaScope
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
accountRepo AccountRepository // 用于智能重试的模型级别限流
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
requestedModel string // 用于限流检查的原始请求模型
isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断)
groupID int64 // 用于模型级限流时清除粘性会话
sessionHash string // 用于模型级限流时清除粘性会话
}
// antigravityRetryLoopResult 重试循环的结果
type antigravityRetryLoopResult struct {
resp *http.Response
}
// smartRetryAction 智能重试的处理结果
type smartRetryAction int
const (
smartRetryActionContinue smartRetryAction = iota // 继续默认重试逻辑
smartRetryActionBreakWithResp // 结束循环并返回 resp
smartRetryActionContinueURL // 继续 URL fallback 循环
)
// smartRetryResult 智能重试的结果
type smartRetryResult struct {
action smartRetryAction
resp *http.Response
err error
switchError *AntigravityAccountSwitchError // 模型限流时返回账号切换信号
}
// handleSmartRetry 处理 OAuth 账号的智能重试逻辑
// 将 429/503 限流处理逻辑抽取为独立函数,减少 antigravityRetryLoop 的复杂度
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
// "Resource has been exhausted" 是 URL 级别限流,切换 URL仅 429
if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
return &smartRetryResult{action: smartRetryActionContinueURL}
}
// 判断是否触发智能重试
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel {
log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)",
p.prefix, resp.StatusCode, modelName, p.account.ID)
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID)
} else {
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
// 返回账号切换信号,让上层切换账号重试
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
switchError: &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: modelName,
IsStickySession: p.isStickySession,
},
}
}
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
if shouldSmartRetry {
var lastRetryResp *http.Response
var lastRetryBody []byte
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-time.After(waitDuration):
}
// 智能重试:创建新请求
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
},
}
}
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
}
// 网络错误时,继续重试
if retryErr != nil || retryResp == nil {
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
continue
}
// 重试失败,关闭之前的响应
if lastRetryResp != nil {
_ = lastRetryResp.Body.Close()
}
lastRetryResp = retryResp
if retryResp != nil {
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
}
// 解析新的重试信息,用于下次重试的等待时间
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newShouldRetry && newWaitDuration > 0 {
waitDuration = newWaitDuration
}
}
}
// 所有重试都失败,限流当前模型并切换账号
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID)
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
if p.accountRepo != nil && modelName != "" {
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
} else {
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration)
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
}
// 清除粘性会话绑定,避免下次请求仍命中限流账号
if s.cache != nil && p.sessionHash != "" {
_ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
// 返回账号切换信号,让上层切换账号重试
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
switchError: &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: modelName,
IsStickySession: p.isStickySession,
},
}
}
// 未触发智能重试,继续默认重试逻辑
return &smartRetryResult{action: smartRetryActionContinue}
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
}
}
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs
}
var resp *http.Response
var usedBaseURL string
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
getUpstreamDetail := func(body []byte) string {
if !logBody {
return ""
}
return truncateString(string(body), maxBytes)
}
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
return nil, p.ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
return nil, err
}
// Capture upstream request body for ops retry of this attempt.
if p.c != nil && len(p.body) > 0 {
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
}
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && resp == nil {
err = errors.New("upstream returned nil response")
}
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
setOpsUpstreamError(p.c, 0, safeErr, "")
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
// 统一处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// ★ 统一入口:自定义错误码 + 临时不可调度
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {
return nil, policyErr
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
// 尝试智能重试处理OAuth 账号专用)
smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs)
switch smartResult.action {
case smartRetryActionContinueURL:
continue urlFallbackLoop
case smartRetryActionBreakWithResp:
if smartResult.err != nil {
return nil, smartResult.err
}
// 模型限流时返回切换账号信号
if smartResult.switchError != nil {
return nil, smartResult.switchError
}
resp = smartResult.resp
break urlFallbackLoop
}
// smartRetryActionContinue: 继续默认重试逻辑
// 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败)
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
// 重试用尽,标记账户限流
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
// 其他可重试错误500/502/504/529不包括 429 和 503
if shouldRetryAntigravityError(resp.StatusCode) {
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
}
// 其他 4xx 错误或重试用尽,直接返回
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break urlFallbackLoop
}
// 成功响应(< 400
break urlFallbackLoop
}
}
if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
}
return &antigravityRetryLoopResult{resp: resp}, nil
}
// shouldRetryAntigravityError 判断是否应该重试
func shouldRetryAntigravityError(statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
default:
return false
}
}
// isURLLevelRateLimit 判断是否为 URL 级别的限流(应切换 URL 重试)
// "Resource has been exhausted" 是 URL/节点级别限流,切换 URL 可能成功
// "exhausted your capacity on this model" 是账户/模型配额限流,切换 URL 无效
func isURLLevelRateLimit(body []byte) bool {
// 快速检查:包含 "Resource has been exhausted" 且不包含 "capacity on this model"
bodyStr := string(body)
return strings.Contains(bodyStr, "Resource has been exhausted") &&
!strings.Contains(bodyStr, "capacity on this model")
}
// isAntigravityConnectionError 判断是否为连接错误网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool {
if err == nil {
return false
}
// 检查超时错误
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// 检查连接错误DNS 失败、连接拒绝)
var opErr *net.OpError
return errors.As(err, &opErr)
}
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
if isAntigravityConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests
}
// getSessionID 从 gin.Context 获取 session_id用于日志追踪
func getSessionID(c *gin.Context) string {
if c == nil {
return ""
}
return c.GetHeader("session_id")
}
// logPrefix 生成统一的日志前缀
func logPrefix(sessionID, accountName string) string {
if sessionID != "" {
return fmt.Sprintf("[antigravity-Forward] session=%s account=%s", sessionID, accountName)
}
return fmt.Sprintf("[antigravity-Forward] account=%s", accountName)
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
type AntigravityGatewayService struct {
accountRepo AccountRepository
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
schedulerSnapshot *SchedulerSnapshotService
}
func NewAntigravityGatewayService(
accountRepo AccountRepository,
cache GatewayCache,
schedulerSnapshot *SchedulerSnapshotService,
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
cache: cache,
schedulerSnapshot: schedulerSnapshot,
}
}
// GetTokenProvider 返回 token provider
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
return s.tokenProvider
}
// getLogConfig 获取上游错误日志配置
// 返回是否记录日志体和最大字节数
func (s *AntigravityGatewayService) getLogConfig() (logBody bool, maxBytes int) {
maxBytes = 2048 // 默认值
if s.settingService == nil || s.settingService.cfg == nil {
return false, maxBytes
}
cfg := s.settingService.cfg.Gateway
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
}
return cfg.LogUpstreamErrorBody, maxBytes
}
// getUpstreamErrorDetail 获取上游错误详情(用于日志记录)
func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
logBody, maxBytes := s.getLogConfig()
if !logBody {
return ""
}
return truncateString(string(body), maxBytes)
}
// checkErrorPolicy nil 安全的包装
func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult {
if s.rateLimitService == nil {
return ErrorPolicyNone
}
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
}
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
case ErrorPolicySkipped:
return true, nil
case ErrorPolicyMatched:
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
return true, nil
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
}
return false, nil
}
// mapAntigravityModel 获取映射后的模型名
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底DefaultAntigravityModelMapping
// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号
func mapAntigravityModel(account *Account, requestedModel string) string {
if account == nil {
return ""
}
// 获取映射表(未配置时自动使用 DefaultAntigravityModelMapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
return "" // 无映射配置(非 Antigravity 平台)
}
// 通过映射表查询(支持精确匹配 + 通配符)
mapped := account.GetMappedModel(requestedModel)
// 判断是否映射成功mapped != requestedModel 说明找到了映射规则)
if mapped != requestedModel {
return mapped
}
// 如果 mapped == requestedModel检查是否在映射表中配置精确或通配符
// 这区分两种情况:
// 1. 映射表中有 "model-a": "model-a"(显式透传)→ 返回 model-a
// 2. 通配符匹配 "claude-*": "claude-sonnet-4-5" 恰好目标等于请求名 → 返回 model-a
// 3. 映射表中没有 model-a 的配置 → 返回空(不支持)
if account.IsModelSupported(requestedModel) {
return requestedModel
}
// 未在映射表中配置的模型,返回空字符串(不支持)
return ""
}
// getMappedModel 获取映射后的模型名
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
return mapAntigravityModel(account, requestedModel)
}
// applyThinkingModelSuffix 根据 thinking 配置调整模型名
// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
if !thinkingEnabled {
return mappedModel
}
if mappedModel == "claude-sonnet-4-5" {
return "claude-sonnet-4-5-thinking"
}
return mappedModel
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// TestConnectionResult 测试连接结果
type TestConnectionResult struct {
Text string // 响应文本
MappedModel string // 实际使用的模型
}
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 获取 token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id部分账户类型可能没有
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// 模型映射
mappedModel := s.getMappedModel(account, modelID)
if mappedModel == "" {
return nil, fmt.Errorf("model %s not in whitelist", modelID)
}
// 构建请求体
var requestBody []byte
if strings.HasPrefix(modelID, "gemini-") {
// Gemini 模型:直接使用 Gemini 格式
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
} else {
// Claude 模型:使用协议转换
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
}
if err != nil {
return nil, fmt.Errorf("构建请求失败: %w", err)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error
for urlIdx, baseURL := range availableURLs {
// 构建 HTTP 请求(总是使用流式 endpoint与官方客户端一致
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
if err != nil {
lastErr = err
continue
}
// 调试日志Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
}
// 读取响应
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
// 标记成功的 URL下次优先使用
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
}, nil
}
return nil, lastErr
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "."},
},
},
},
// Antigravity 上游要求必须包含身份提示词
"systemInstruction": map[string]any{
"parts": []map[string]any{
{"text": antigravity.GetDefaultIdentityPatch()},
},
},
"generationConfig": map[string]any{
"maxOutputTokens": 1,
},
}
payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
}
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
claudeReq := &antigravity.ClaudeRequest{
Model: mappedModel,
Messages: []antigravity.ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`"."`),
},
},
MaxTokens: 1,
Stream: false,
}
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
}
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
opts := antigravity.DefaultTransformOptions()
if s.settingService == nil {
return opts
}
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
return opts
}
// extractTextFromSSEResponse 从 SSE 流式响应中提取文本
func extractTextFromSSEResponse(respBody []byte) string {
var texts []string
lines := bytes.Split(respBody, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
// 跳过 SSE 前缀
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimPrefix(line, []byte("data:"))
line = bytes.TrimSpace(line)
}
// 跳过非 JSON 行
if len(line) == 0 || line[0] != '{' {
continue
}
// 解析 JSON
var data map[string]any
if err := json.Unmarshal(line, &data); err != nil {
continue
}
// 尝试从 response.candidates[0].content.parts[].text 提取
response, ok := data["response"].(map[string]any)
if !ok {
// 尝试直接从 candidates 提取(某些响应格式)
response = data
}
candidates, ok := response["candidates"].([]any)
if !ok || len(candidates) == 0 {
continue
}
candidate, ok := candidates[0].(map[string]any)
if !ok {
continue
}
content, ok := candidate["content"].(map[string]any)
if !ok {
continue
}
parts, ok := content["parts"].([]any)
if !ok {
continue
}
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok && text != "" {
texts = append(texts, text)
}
}
}
}
return strings.Join(texts, "")
}
// injectIdentityPatchToGeminiRequest 为 Gemini 格式请求注入身份提示词
// 如果请求中已包含 "You are Antigravity" 则不重复注入
func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
var request map[string]any
if err := json.Unmarshal(body, &request); err != nil {
return nil, fmt.Errorf("解析 Gemini 请求失败: %w", err)
}
// 检查现有 systemInstruction 是否已包含身份提示词
if sysInst, ok := request["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
if strings.Contains(text, "You are Antigravity") {
// 已包含身份提示词,直接返回原始请求
return body, nil
}
}
}
}
}
}
// 获取默认身份提示词
identityPatch := antigravity.GetDefaultIdentityPatch()
// 构建新的 systemInstruction
newPart := map[string]any{"text": identityPatch}
if existing, ok := request["systemInstruction"].(map[string]any); ok {
// 已有 systemInstruction在开头插入身份提示词
if parts, ok := existing["parts"].([]any); ok {
existing["parts"] = append([]any{newPart}, parts...)
} else {
existing["parts"] = []any{newPart}
}
} else {
// 没有 systemInstruction创建新的
request["systemInstruction"] = map[string]any{
"parts": []any{newPart},
}
}
return json.Marshal(request)
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
if err := json.Unmarshal(originalBody, &request); err != nil {
return nil, fmt.Errorf("解析请求体失败: %w", err)
}
wrapped := map[string]any{
"project": projectID,
"requestId": "agent-" + uuid.New().String(),
"userAgent": "antigravity", // 固定值,与官方客户端一致
"requestType": "agent",
"model": model,
"request": request,
}
return json.Marshal(wrapped)
}
// unwrapV1InternalResponse 解包 v1internal 响应
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 解析 Claude 请求
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
}
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
if mappedModel == "" {
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
}
loadModel := mappedModel
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token
if s.tokenProvider == nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token")
}
// 获取 project_id部分账户类型可能没有
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
transformOpts.EnableIdentityPatch = true // 强制启用Antigravity 上游必需
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel)
}
// 执行带重试的请求
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession, // Forward 由上层判断粘性会话
groupID: 0, // Forward 方法没有 groupID由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash由上层处理粘性会话清除
})
if err != nil {
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
resp := result.resp
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 优先检测 thinking block 的 signature 相关错误400并重试一次
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody, maxBytes := s.getLogConfig()
upstreamDetail := s.getUpstreamErrorDetail(respBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
retryStages := []struct {
name string
strip func(*antigravity.ClaudeRequest) (bool, error)
}{
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
}
for _, stage := range retryStages {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stage.strip(&retryClaudeReq)
if stripErr != nil || !stripped {
continue
}
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if txErr != nil {
continue
}
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0, // Forward 方法没有 groupID由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash由上层处理粘性会话清除
})
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
break
}
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := ""
if retryResp.Request != nil && retryResp.Request.URL != nil {
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
}
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
}
kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" {
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
}
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
retryUpstreamDetail := ""
if logBody {
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: kind,
Message: retryUpstreamMsg,
Detail: retryUpstreamDetail,
})
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
break
}
// Still signature-related; capture context and allow next stage.
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
}
}
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := s.getUpstreamErrorDetail(respBody)
logBody, maxBytes := s.getLogConfig()
if logBody {
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "prompt_too_long",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &PromptTooLongError{
StatusCode: resp.StatusCode,
RequestID: resp.Header.Get("x-request-id"),
Body: respBody,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := s.getUpstreamErrorDetail(respBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
}
}
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
func isSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
// Fallback: best-effort scan of the raw payload.
msg = strings.ToLower(string(respBody))
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false
}
// isPromptTooLongError 检测是否为 prompt too long 错误
func isPromptTooLongError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "prompt is too long") ||
strings.Contains(msg, "request is too long") ||
strings.Contains(msg, "context length exceeded") ||
strings.Contains(msg, "max_tokens")
}
// isPassthroughErrorMessage 检查错误消息是否在透传白名单中
func isPassthroughErrorMessage(msg string) bool {
lower := strings.ToLower(msg)
for _, pattern := range antigravityPassthroughErrorMessages {
if strings.Contains(lower, pattern) {
return true
}
}
return false
}
// getPassthroughOrDefault 若消息在白名单内则返回原始消息,否则返回默认消息
func getPassthroughOrDefault(upstreamMsg, defaultMsg string) string {
if isPassthroughErrorMessage(upstreamMsg) {
return upstreamMsg
}
return defaultMsg
}
func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
return ""
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert thinking blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
modifiedAny = true
case "":
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
if len(filtered) == 0 {
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
// Convert thinking to text, skip if empty
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "tool_use":
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name, _ := block["name"].(string)
id, _ := block["id"].(string)
input := block["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "tool_result":
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID, _ := block["tool_use_id"].(string)
isError, _ := block["is_error"].(bool)
content := block["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
if len(filtered) == 0 {
// Keep request valid: upstream rejects empty content arrays.
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// ForwardGemini 转发 Gemini 协议请求
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 解析请求以获取 image_size用于图片计费
imageSize := s.extractImageSize(body)
switch action {
case "generateContent", "streamGenerateContent":
// ok
case "countTokens":
// 直接返回空值,不透传上游
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
FirstTokenMs: nil,
}, nil
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := s.getMappedModel(account, originalModel)
if mappedModel == "" {
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
}
// 获取 access_token
if s.tokenProvider == nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token")
}
// 获取 project_id部分账户类型可能没有
projectID := strings.TrimSpace(account.GetCredential("project_id"))
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Invalid request body")
}
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
if s.cache != nil {
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
}
// 执行带重试的请求
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession, // ForwardGemini 由上层判断粘性会话
groupID: 0, // ForwardGemini 方法没有 groupID由上层处理粘性会话清除
sessionHash: "", // ForwardGemini 方法没有 sessionHash由上层处理粘性会话清除
})
if err != nil {
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
resp := result.resp
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
}()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
}
}
}
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess
}
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
unwrappedForOps := unwrapped
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps)
// Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
}
if contentType == "" {
contentType = "application/json"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
handleSuccess:
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if stream {
// 客户端要求流式,直接透传
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
if usage == nil {
usage = &ClaudeUsage{}
}
// 判断是否为图片生成模型
imageCount := 0
if isImageGenerationModel(mappedModel) {
// Gemini 图片生成 API 每次请求只生成一张图片API 限制)
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := antigravityRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > antigravityRetryMaxDelay {
delay = antigravityRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
select {
case <-ctx.Done():
return false
case <-time.After(sleepFor):
return true
}
}
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
// 直接使用上游返回的模型 ID如 claude-sonnet-4-5作为限流 key
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false
func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, accountID int64, modelName, prefix string, statusCode int, resetAt time.Time, afterSmartRetry bool) bool {
if repo == nil || modelName == "" {
return false
}
// 直接使用官方模型 ID 作为 key不再转换为 scope
if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil {
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
return false
}
if afterSmartRetry {
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
} else {
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
}
return true
}
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
if raw == "" {
return 0, false
}
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return 0, false
}
return time.Duration(seconds) * time.Second, true
}
// antigravitySmartRetryInfo 智能重试所需的信息
type antigravitySmartRetryInfo struct {
RetryDelay time.Duration // 重试延迟时间
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5"
}
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
// 返回解析结果,如果解析失败或不满足条件返回 nil
//
// 支持两种情况:
// 1. 429 RESOURCE_EXHAUSTED + RATE_LIMIT_EXCEEDED
// - error.status == "RESOURCE_EXHAUSTED"
// - error.details[].reason == "RATE_LIMIT_EXCEEDED"
//
// 2. 503 UNAVAILABLE + MODEL_CAPACITY_EXHAUSTED
// - error.status == "UNAVAILABLE"
// - error.details[].reason == "MODEL_CAPACITY_EXHAUSTED"
//
// 必须满足以下条件才会返回有效值:
// - error.details[] 中存在 @type == "type.googleapis.com/google.rpc.RetryInfo" 的元素
// - 该元素包含 retryDelay 字段,格式为 "数字s"(如 "0.201506475s"
func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
// 检查 status 是否符合条件
// 情况1: 429 RESOURCE_EXHAUSTED (需要进一步检查 reason == RATE_LIMIT_EXCEEDED)
// 情况2: 503 UNAVAILABLE (需要进一步检查 reason == MODEL_CAPACITY_EXHAUSTED)
status, _ := errObj["status"].(string)
isResourceExhausted := status == googleRPCStatusResourceExhausted
isUnavailable := status == googleRPCStatusUnavailable
if !isResourceExhausted && !isUnavailable {
return nil
}
details, ok := errObj["details"].([]any)
if !ok {
return nil
}
var retryDelay time.Duration
var modelName string
var hasRateLimitExceeded bool // 429 需要此 reason
var hasModelCapacityExhausted bool // 503 需要此 reason
for _, d := range details {
dm, ok := d.(map[string]any)
if !ok {
continue
}
atType, _ := dm["@type"].(string)
// 从 ErrorInfo 提取模型名称和 reason
if atType == googleRPCTypeErrorInfo {
if meta, ok := dm["metadata"].(map[string]any); ok {
if model, ok := meta["model"].(string); ok {
modelName = model
}
}
// 检查 reason
if reason, ok := dm["reason"].(string); ok {
if reason == googleRPCReasonModelCapacityExhausted {
hasModelCapacityExhausted = true
}
if reason == googleRPCReasonRateLimitExceeded {
hasRateLimitExceeded = true
}
}
continue
}
// 从 RetryInfo 提取重试延迟
if atType == googleRPCTypeRetryInfo {
delay, ok := dm["retryDelay"].(string)
if !ok || delay == "" {
continue
}
// 使用 time.ParseDuration 解析,支持所有 Go duration 格式
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
dur, err := time.ParseDuration(delay)
if err != nil {
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
continue
}
retryDelay = dur
}
}
// 验证条件
// 情况1: RESOURCE_EXHAUSTED 需要有 RATE_LIMIT_EXCEEDED reason
// 情况2: UNAVAILABLE 需要有 MODEL_CAPACITY_EXHAUSTED reason
if isResourceExhausted && !hasRateLimitExceeded {
return nil
}
if isUnavailable && !hasModelCapacityExhausted {
return nil
}
// 必须有模型名才返回有效结果
if modelName == "" {
return nil
}
// 如果上游未提供 retryDelay使用默认限流时间
if retryDelay <= 0 {
retryDelay = antigravityDefaultRateLimitDuration
}
return &antigravitySmartRetryInfo{
RetryDelay: retryDelay,
ModelName: modelName,
}
}
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
// 返回:
// - shouldRetry: 是否应该智能重试retryDelay < antigravityRateLimitThreshold
// - shouldRateLimitModel: 是否应该限流模型retryDelay >= antigravityRateLimitThreshold
// - waitDuration: 等待时间智能重试时使用shouldRateLimitModel=true 时为 0
// - modelName: 限流的模型名称
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
if account.Platform != PlatformAntigravity {
return false, false, 0, ""
}
info := parseAntigravitySmartRetryInfo(respBody)
if info == nil {
return false, false, 0, ""
}
// retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelayparseAntigravitySmartRetryInfo 已设置为默认 5 分钟
if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, 0, info.ModelName
}
// retryDelay < 阈值:智能重试
waitDuration = info.RetryDelay
if waitDuration < antigravitySmartRetryMinWait {
waitDuration = antigravitySmartRetryMinWait
}
return true, false, waitDuration, info.ModelName
}
// handleModelRateLimitParams 模型级限流处理参数
type handleModelRateLimitParams struct {
ctx context.Context
prefix string
account *Account
statusCode int
body []byte
cache GatewayCache
groupID int64
sessionHash string
isStickySession bool
}
// handleModelRateLimitResult 模型级限流处理结果
type handleModelRateLimitResult struct {
Handled bool // 是否已处理
ShouldRetry bool // 是否等待后重试
WaitDuration time.Duration // 等待时间
SwitchError *AntigravityAccountSwitchError // 账号切换错误
}
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
// 仅处理 429/503解析模型名和 retryDelay
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true由调用方等待后重试
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
if p.statusCode != 429 && p.statusCode != 503 {
return &handleModelRateLimitResult{Handled: false}
}
info := parseAntigravitySmartRetryInfo(p.body)
if info == nil || info.ModelName == "" {
return &handleModelRateLimitResult{Handled: false}
}
// < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold {
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
return &handleModelRateLimitResult{
Handled: true,
ShouldRetry: true,
WaitDuration: info.RetryDelay,
}
}
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
s.setModelRateLimitAndClearSession(p, info)
return &handleModelRateLimitResult{
Handled: true,
SwitchError: &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: info.ModelName,
IsStickySession: p.isStickySession,
},
}
}
// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
resetAt := time.Now().Add(info.RetryDelay)
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay)
// 设置模型限流状态(数据库)
if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil {
log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
}
// 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中
s.updateAccountModelRateLimitInCache(p.ctx, p.account, info.ModelName, resetAt)
// 清除粘性会话绑定
if p.cache != nil && p.sessionHash != "" {
_ = p.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
}
// updateAccountModelRateLimitInCache 立即更新 Redis 中账号的模型限流状态
func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx context.Context, account *Account, modelKey string, resetAt time.Time) {
if s.schedulerSnapshot == nil || account == nil || modelKey == "" {
return
}
// 更新账号对象的 Extra 字段
if account.Extra == nil {
account.Extra = make(map[string]any)
}
limits, _ := account.Extra["model_rate_limits"].(map[string]any)
if limits == nil {
limits = make(map[string]any)
account.Extra["model_rate_limits"] = limits
}
limits[modelKey] = map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
// 更新 Redis 快照
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
}
}
func (s *AntigravityGatewayService) handleUpstreamError(
ctx context.Context, prefix string, account *Account,
statusCode int, headers http.Header, body []byte,
quotaScope AntigravityQuotaScope,
groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult {
// ✨ 模型级限流处理(在原有逻辑之前)
result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx,
prefix: prefix,
account: account,
statusCode: statusCode,
body: body,
cache: s.cache,
groupID: groupID,
sessionHash: sessionHash,
isStickySession: isStickySession,
})
if result.Handled {
return result
}
// 503 仅处理模型限流MODEL_CAPACITY_EXHAUSTED非模型限流不做额外处理
// 避免将普通的 503 错误误判为账号问题
if statusCode == 503 {
return nil
}
// ========== 原有逻辑,保持不变 ==========
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
// 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。
if logBody, maxBytes := s.getLogConfig(); logBody {
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
}
useScopeLimit := quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败:使用默认限流时间(与临时限流保持一致)
// 可通过配置或环境变量覆盖
defaultDur := antigravityDefaultRateLimitDuration
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
}
// 秒级环境变量优先级最高
if override, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = override
}
ra := time.Now().Add(defaultDur)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return nil
}
resetTime := time.Unix(*resetAt, 0)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return nil
}
// 其他错误码继续使用 rateLimitService
if s.rateLimitService == nil {
return nil
}
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
if shouldDisable {
log.Printf("%s status=%d marked_error", prefix, statusCode)
}
return nil
}
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
// 断开后所有写入操作变为 no-op调用方通过 Disconnected() 判断是否继续 drain 上游。
type antigravityClientWriter struct {
w gin.ResponseWriter
flusher http.Flusher
disconnected bool
prefix string // 日志前缀,标识来源方法
}
func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter {
return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix}
}
// Write 写入数据到客户端,写入失败时标记断开并返回 false
func (cw *antigravityClientWriter) Write(p []byte) bool {
if cw.disconnected {
return false
}
if _, err := cw.w.Write(p); err != nil {
cw.markDisconnected()
return false
}
cw.flusher.Flush()
return true
}
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool {
if cw.disconnected {
return false
}
if _, err := fmt.Fprintf(cw.w, format, args...); err != nil {
cw.markDisconnected()
return false
}
cw.flusher.Flush()
return true
}
func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected }
func (cw *antigravityClientWriter) markDisconnected() {
cw.disconnected = true
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
}
// handleStreamReadError 处理上游读取错误的通用逻辑。
// 返回 (clientDisconnect, handled)handled=true 表示错误已处理,调用方应返回已收集的 usage。
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
return true, true
}
if clientDisconnected {
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
return true, true
}
return false, false
}
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
cw.Fprintf("%s\n", line)
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
cw.Fprintf("data: %s\n\n", payload)
continue
}
cw.Fprintf("%s\n", line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
// Gemini 流式响应是增量的,需要累积所有 chunk 的内容
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
var collectedTextParts []string // 收集所有文本片段
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,返回收集的响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 提取 usage
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// 收集包含图片和文本的 parts
for _, part := range parts {
if inlineData, ok := part["inlineData"].(map[string]any); ok {
collectedImageParts = append(collectedImageParts, part)
_ = inlineData // 避免 unused 警告
}
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
// 如果收集到了图片 parts需要合并到最终响应中
if len(collectedImageParts) > 0 {
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
}
// 如果收集到了文本,需要合并到最终响应中
if len(collectedTextParts) > 0 {
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
}
respBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
// 深拷贝 response
result = make(map[string]any)
for k, v := range response {
result[k] = v
}
// 获取或创建 candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// 获取第一个 candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// 获取或创建 content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// 获取现有 parts
existingParts, ok = content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// 返回更新回调
setParts = func(newParts []any) {
content["parts"] = newParts
result["candidates"] = candidates
}
return result, existingParts, setParts
}
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 partstext、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
if len(collectedParts) == 0 {
return response
}
result, _, setParts := getOrCreateGeminiParts(response)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var mergedParts []any
var textBuffer strings.Builder
flushTextBuffer := func() {
if textBuffer.Len() > 0 {
mergedParts = append(mergedParts, map[string]any{
"text": textBuffer.String(),
})
textBuffer.Reset()
}
}
for _, part := range collectedParts {
// 检查是否是普通 text part
if text, ok := part["text"].(string); ok {
// 检查是否有 thought 标记
if thought, _ := part["thought"].(bool); thought {
// thinking part先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
} else {
// 普通 text累积到 buffer
_, _ = textBuffer.WriteString(text)
}
} else {
// 非 text partfunctionCall、inlineData 等),先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
}
}
// 刷新剩余的 text
flushTextBuffer()
setParts(mergedParts)
return result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 {
return response
}
result, existingParts, setParts := getOrCreateGeminiParts(response)
// 检查现有 parts 中是否已经有图片
for _, p := range existingParts {
if pm, ok := p.(map[string]any); ok {
if _, hasInline := pm["inlineData"]; hasInline {
return result // 已有图片,不重复添加
}
}
}
// 添加收集到的图片 parts
for _, imgPart := range imageParts {
existingParts = append(existingParts, imgPart)
}
setParts(existingParts)
return result
}
// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中
func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
mergedText := strings.Join(textParts, "")
result, existingParts, setParts := getOrCreateGeminiParts(response)
// 查找并更新第一个 text part或创建新的
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// 用累积的文本替换
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
setParts(newParts)
return result
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": message},
})
return fmt.Errorf("%s", message)
}
// WriteMappedClaudeError 导出版本,供 handler 层使用(如 fallback 错误处理)
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody, maxBytes := s.getLogConfig()
upstreamDetail := s.getUpstreamErrorDetail(body)
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
if logBody {
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
}
var statusCode int
var errType, errMsg string
switch upstreamStatus {
case 400:
statusCode = http.StatusBadRequest
errType = "invalid_request_error"
errMsg = getPassthroughOrDefault(upstreamMsg, "Invalid request")
case 401:
statusCode = http.StatusBadGateway
errType = "authentication_error"
errMsg = "Upstream authentication failed"
case 403:
statusCode = http.StatusBadGateway
errType = "permission_error"
errMsg = "Upstream access forbidden"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN"
switch status {
case 400:
statusStr = "INVALID_ARGUMENT"
case 404:
statusStr = "NOT_FOUND"
case 429:
statusStr = "RESOURCE_EXHAUSTED"
case 500:
statusStr = "INTERNAL"
case 502, 503:
statusStr = "UNAVAILABLE"
}
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": statusStr,
},
})
return fmt.Errorf("%s", message)
}
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
// 用于处理客户端非流式请求但上游只支持流式的情况
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
var collectedParts []map[string]any // 收集所有 parts包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,转换并返回响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// 收集所有 partstext、thinking、functionCall、inlineData 等)
collectedParts = append(collectedParts, parts...)
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 将收集的所有 parts 合并到最终响应中
if len(collectedParts) > 0 {
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 序列化为 JSONGemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
}
// 转换 Gemini 响应为 Claude 格式
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
if err != nil {
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
c.Data(http.StatusOK, "application/json", claudeResp)
// 转换为 service.ClaudeUsage
usage := &ClaudeUsage{
InputTokens: agUsage.InputTokens,
OutputTokens: agUsage.OutputTokens,
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
if agUsage == nil {
return &ClaudeUsage{}
}
return &ClaudeUsage{
InputTokens: agUsage.InputTokens,
OutputTokens: agUsage.OutputTokens,
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
}
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
// finishUsage 是获取 processor 最终 usage 的辅助函数
finishUsage := func() *ClaudeUsage {
_, agUsage := processor.Finish()
return convertUsage(agUsage)
}
for {
select {
case ev, ok := <-events:
if !ok {
// 上游完成,发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
cw.Write(finalEvents)
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
if len(claudeEvents) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
cw.Write(claudeEvents)
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
var req antigravity.GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
return "2K" // 默认 2K
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K" // 默认 2K
}
// isImageGenerationModel 判断模型是否为图片生成模型
// 支持的模型gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
func isImageGenerationModel(model string) bool {
modelLower := strings.ToLower(model)
// 移除 models/ 前缀
modelLower = strings.TrimPrefix(modelLower, "models/")
// 精确匹配或前缀匹配
return modelLower == "gemini-3-pro-image" ||
modelLower == "gemini-3-pro-image-preview" ||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
modelLower == "gemini-2.5-flash-image" ||
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
}
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
func cleanGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modified := false
// 1. 清理 Tools
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
for _, t := range tools {
toolMap, ok := t.(map[string]any)
if !ok {
continue
}
// function_declarations (snake_case) or functionDeclarations (camelCase)
var funcs []any
if f, ok := toolMap["functionDeclarations"].([]any); ok {
funcs = f
} else if f, ok := toolMap["function_declarations"].([]any); ok {
funcs = f
}
if len(funcs) == 0 {
continue
}
for _, f := range funcs {
funcMap, ok := f.(map[string]any)
if !ok {
continue
}
if params, ok := funcMap["parameters"].(map[string]any); ok {
antigravity.DeepCleanUndefined(params)
cleaned := antigravity.CleanJSONSchema(params)
funcMap["parameters"] = cleaned
modified = true
}
}
}
}
if !modified {
return body, nil
}
return json.Marshal(payload)
}
// filterEmptyPartsFromGeminiRequest 过滤掉 parts 为空的消息
// Gemini API 不接受空 parts需要在请求前过滤
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
contents, ok := payload["contents"].([]any)
if !ok || len(contents) == 0 {
return body, nil
}
filtered := make([]any, 0, len(contents))
modified := false
for _, c := range contents {
contentMap, ok := c.(map[string]any)
if !ok {
filtered = append(filtered, c)
continue
}
parts, hasParts := contentMap["parts"]
if !hasParts {
filtered = append(filtered, c)
continue
}
partsSlice, ok := parts.([]any)
if !ok {
filtered = append(filtered, c)
continue
}
// 跳过 parts 为空数组的消息
if len(partsSlice) == 0 {
modified = true
continue
}
filtered = append(filtered, c)
}
if !modified {
return body, nil
}
payload["contents"] = filtered
return json.Marshal(payload)
}
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", false)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
streamRes := s.streamUpstreamResponse(c, resp, startTime)
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
usage := &ClaudeUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
for {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
}
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
line := ev.line
// 记录首 token 时间
if firstTokenMs == nil && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
s.extractSSEUsage(line, usage)
// 透传行
cw.Fprintf("%s\n", line)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
}
log.Printf("Stream data interval timeout (antigravity upstream)")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
}
}
// extractSSEUsage 从 SSE data 行中提取 Claude usage用于流式透传场景
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
if !strings.HasPrefix(line, "data: ") {
return
}
dataStr := strings.TrimPrefix(line, "data: ")
var event map[string]any
if json.Unmarshal([]byte(dataStr), &event) != nil {
return
}
u, ok := event["usage"].(map[string]any)
if !ok {
return
}
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}