When custom error codes are enabled and the upstream error code is NOT in the configured list, return HTTP 500 to the client instead of transparently forwarding the original status code. Also adds integration test TestCustomErrorCode599 verifying that 429, 500, 503, 401, 403 all return 500 without triggering SetRateLimited or SetError.
3746 lines
124 KiB
Go
3746 lines
124 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"log/slog"
|
||
mathrand "math/rand"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
const (
|
||
antigravityStickySessionTTL = time.Hour
|
||
antigravityMaxRetries = 3
|
||
antigravityRetryBaseDelay = 1 * time.Second
|
||
antigravityRetryMaxDelay = 16 * time.Second
|
||
|
||
// 限流相关常量
|
||
// antigravityRateLimitThreshold 限流等待/切换阈值
|
||
// - 智能重试:retryDelay < 此阈值时等待后重试,>= 此阈值时直接限流模型
|
||
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
|
||
antigravityRateLimitThreshold = 7 * time.Second
|
||
antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间
|
||
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
|
||
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
|
||
|
||
// Google RPC 状态和类型常量
|
||
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
|
||
googleRPCStatusUnavailable = "UNAVAILABLE"
|
||
googleRPCTypeRetryInfo = "type.googleapis.com/google.rpc.RetryInfo"
|
||
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
|
||
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
|
||
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
|
||
)
|
||
|
||
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
||
// 匹配时使用 strings.Contains,无需完全匹配
|
||
var antigravityPassthroughErrorMessages = []string{
|
||
"prompt is too long",
|
||
}
|
||
|
||
const (
|
||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||
)
|
||
|
||
// AntigravityAccountSwitchError 账号切换信号
|
||
// 当账号限流时间超过阈值时,通知上层切换账号
|
||
type AntigravityAccountSwitchError struct {
|
||
OriginalAccountID int64
|
||
RateLimitedModel string
|
||
IsStickySession bool // 是否为粘性会话切换(决定是否缓存计费)
|
||
}
|
||
|
||
func (e *AntigravityAccountSwitchError) Error() string {
|
||
return fmt.Sprintf("account %d model %s rate limited, need switch",
|
||
e.OriginalAccountID, e.RateLimitedModel)
|
||
}
|
||
|
||
// IsAntigravityAccountSwitchError 检查错误是否为账号切换信号
|
||
func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, bool) {
|
||
var switchErr *AntigravityAccountSwitchError
|
||
if errors.As(err, &switchErr) {
|
||
return switchErr, true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||
type PromptTooLongError struct {
|
||
StatusCode int
|
||
RequestID string
|
||
Body []byte
|
||
}
|
||
|
||
func (e *PromptTooLongError) Error() string {
|
||
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||
}
|
||
|
||
// antigravityRetryLoopParams 重试循环的参数
|
||
type antigravityRetryLoopParams struct {
|
||
ctx context.Context
|
||
prefix string
|
||
account *Account
|
||
proxyURL string
|
||
accessToken string
|
||
action string
|
||
body []byte
|
||
c *gin.Context
|
||
httpUpstream HTTPUpstream
|
||
settingService *SettingService
|
||
accountRepo AccountRepository // 用于智能重试的模型级别限流
|
||
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
|
||
requestedModel string // 用于限流检查的原始请求模型
|
||
isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断)
|
||
groupID int64 // 用于模型级限流时清除粘性会话
|
||
sessionHash string // 用于模型级限流时清除粘性会话
|
||
}
|
||
|
||
// antigravityRetryLoopResult 重试循环的结果
|
||
type antigravityRetryLoopResult struct {
|
||
resp *http.Response
|
||
}
|
||
|
||
// smartRetryAction 智能重试的处理结果
|
||
type smartRetryAction int
|
||
|
||
const (
|
||
smartRetryActionContinue smartRetryAction = iota // 继续默认重试逻辑
|
||
smartRetryActionBreakWithResp // 结束循环并返回 resp
|
||
smartRetryActionContinueURL // 继续 URL fallback 循环
|
||
)
|
||
|
||
// smartRetryResult 智能重试的结果
|
||
type smartRetryResult struct {
|
||
action smartRetryAction
|
||
resp *http.Response
|
||
err error
|
||
switchError *AntigravityAccountSwitchError // 模型限流时返回账号切换信号
|
||
}
|
||
|
||
// handleSmartRetry 处理 OAuth 账号的智能重试逻辑
|
||
// 将 429/503 限流处理逻辑抽取为独立函数,减少 antigravityRetryLoop 的复杂度
|
||
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
|
||
// "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429)
|
||
if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||
}
|
||
|
||
// 判断是否触发智能重试
|
||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||
|
||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||
if shouldRateLimitModel {
|
||
rateLimitDuration := waitDuration
|
||
if rateLimitDuration <= 0 {
|
||
rateLimitDuration = antigravityDefaultRateLimitDuration
|
||
}
|
||
log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
|
||
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200))
|
||
|
||
resetAt := time.Now().Add(rateLimitDuration)
|
||
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||
log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
|
||
} else {
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||
}
|
||
|
||
// 返回账号切换信号,让上层切换账号重试
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
switchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: modelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
|
||
if shouldSmartRetry {
|
||
var lastRetryResp *http.Response
|
||
var lastRetryBody []byte
|
||
|
||
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
|
||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
|
||
|
||
select {
|
||
case <-p.ctx.Done():
|
||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||
case <-time.After(waitDuration):
|
||
}
|
||
|
||
// 智能重试:创建新请求
|
||
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||
if err != nil {
|
||
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
resp: &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
},
|
||
}
|
||
}
|
||
|
||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||
}
|
||
|
||
// 网络错误时,继续重试
|
||
if retryErr != nil || retryResp == nil {
|
||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
||
continue
|
||
}
|
||
|
||
// 重试失败,关闭之前的响应
|
||
if lastRetryResp != nil {
|
||
_ = lastRetryResp.Body.Close()
|
||
}
|
||
lastRetryResp = retryResp
|
||
if retryResp != nil {
|
||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||
_ = retryResp.Body.Close()
|
||
}
|
||
|
||
// 解析新的重试信息,用于下次重试的等待时间
|
||
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
|
||
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||
if newShouldRetry && newWaitDuration > 0 {
|
||
waitDuration = newWaitDuration
|
||
}
|
||
}
|
||
}
|
||
|
||
// 所有重试都失败,限流当前模型并切换账号
|
||
rateLimitDuration := waitDuration
|
||
if rateLimitDuration <= 0 {
|
||
rateLimitDuration = antigravityDefaultRateLimitDuration
|
||
}
|
||
retryBody := lastRetryBody
|
||
if retryBody == nil {
|
||
retryBody = respBody
|
||
}
|
||
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
||
|
||
resetAt := time.Now().Add(rateLimitDuration)
|
||
if p.accountRepo != nil && modelName != "" {
|
||
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
|
||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
|
||
} else {
|
||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
|
||
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration)
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||
}
|
||
}
|
||
|
||
// 清除粘性会话绑定,避免下次请求仍命中限流账号
|
||
if s.cache != nil && p.sessionHash != "" {
|
||
_ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
|
||
}
|
||
|
||
// 返回账号切换信号,让上层切换账号重试
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
switchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: modelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 未触发智能重试,继续默认重试逻辑
|
||
return &smartRetryResult{action: smartRetryActionContinue}
|
||
}
|
||
|
||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||
// 预检查:如果账号已限流,直接返回切换信号
|
||
if p.requestedModel != "" {
|
||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||
return nil, &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: p.requestedModel,
|
||
IsStickySession: p.isStickySession,
|
||
}
|
||
}
|
||
}
|
||
|
||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||
if len(availableURLs) == 0 {
|
||
availableURLs = antigravity.BaseURLs
|
||
}
|
||
|
||
var resp *http.Response
|
||
var usedBaseURL string
|
||
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||
maxBytes := 2048
|
||
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||
}
|
||
getUpstreamDetail := func(body []byte) string {
|
||
if !logBody {
|
||
return ""
|
||
}
|
||
return truncateString(string(body), maxBytes)
|
||
}
|
||
|
||
urlFallbackLoop:
|
||
for urlIdx, baseURL := range availableURLs {
|
||
usedBaseURL = baseURL
|
||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||
select {
|
||
case <-p.ctx.Done():
|
||
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||
return nil, p.ctx.Err()
|
||
default:
|
||
}
|
||
|
||
upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Capture upstream request body for ops retry of this attempt.
|
||
if p.c != nil && len(p.body) > 0 {
|
||
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
|
||
}
|
||
|
||
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||
if err == nil && resp == nil {
|
||
err = errors.New("upstream returned nil response")
|
||
}
|
||
if err != nil {
|
||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: 0,
|
||
Kind: "request_error",
|
||
Message: safeErr,
|
||
})
|
||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||
continue urlFallbackLoop
|
||
}
|
||
if attempt < antigravityMaxRetries {
|
||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
|
||
setOpsUpstreamError(p.c, 0, safeErr, "")
|
||
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
|
||
}
|
||
|
||
// 统一处理错误响应
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
|
||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||
if policyErr != nil {
|
||
return nil, policyErr
|
||
}
|
||
resp = &http.Response{
|
||
StatusCode: outStatus,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break urlFallbackLoop
|
||
}
|
||
|
||
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
|
||
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
|
||
// 尝试智能重试处理(OAuth 账号专用)
|
||
smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs)
|
||
switch smartResult.action {
|
||
case smartRetryActionContinueURL:
|
||
continue urlFallbackLoop
|
||
case smartRetryActionBreakWithResp:
|
||
if smartResult.err != nil {
|
||
return nil, smartResult.err
|
||
}
|
||
// 模型限流时返回切换账号信号
|
||
if smartResult.switchError != nil {
|
||
return nil, smartResult.switchError
|
||
}
|
||
resp = smartResult.resp
|
||
break urlFallbackLoop
|
||
}
|
||
// smartRetryActionContinue: 继续默认重试逻辑
|
||
|
||
// 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败)
|
||
if attempt < antigravityMaxRetries {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "retry",
|
||
Message: upstreamMsg,
|
||
Detail: getUpstreamDetail(respBody),
|
||
})
|
||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 重试用尽,标记账户限流
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break urlFallbackLoop
|
||
}
|
||
|
||
// 其他可重试错误(500/502/504/529,不包括 429 和 503)
|
||
if shouldRetryAntigravityError(resp.StatusCode) {
|
||
if attempt < antigravityMaxRetries {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "retry",
|
||
Message: upstreamMsg,
|
||
Detail: getUpstreamDetail(respBody),
|
||
})
|
||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 其他 4xx 错误或重试用尽,直接返回
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break urlFallbackLoop
|
||
}
|
||
|
||
// 成功响应(< 400)
|
||
break urlFallbackLoop
|
||
}
|
||
}
|
||
|
||
if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
|
||
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
|
||
}
|
||
|
||
return &antigravityRetryLoopResult{resp: resp}, nil
|
||
}
|
||
|
||
// shouldRetryAntigravityError 判断是否应该重试
|
||
func shouldRetryAntigravityError(statusCode int) bool {
|
||
switch statusCode {
|
||
case 429, 500, 502, 503, 504, 529:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// isURLLevelRateLimit 判断是否为 URL 级别的限流(应切换 URL 重试)
|
||
// "Resource has been exhausted" 是 URL/节点级别限流,切换 URL 可能成功
|
||
// "exhausted your capacity on this model" 是账户/模型配额限流,切换 URL 无效
|
||
func isURLLevelRateLimit(body []byte) bool {
|
||
// 快速检查:包含 "Resource has been exhausted" 且不包含 "capacity on this model"
|
||
bodyStr := string(body)
|
||
return strings.Contains(bodyStr, "Resource has been exhausted") &&
|
||
!strings.Contains(bodyStr, "capacity on this model")
|
||
}
|
||
|
||
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||
func isAntigravityConnectionError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
|
||
// 检查超时错误
|
||
var netErr net.Error
|
||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||
return true
|
||
}
|
||
|
||
// 检查连接错误(DNS 失败、连接拒绝)
|
||
var opErr *net.OpError
|
||
return errors.As(err, &opErr)
|
||
}
|
||
|
||
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
|
||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
|
||
if isAntigravityConnectionError(err) {
|
||
return true
|
||
}
|
||
return statusCode == http.StatusTooManyRequests
|
||
}
|
||
|
||
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
|
||
func getSessionID(c *gin.Context) string {
|
||
if c == nil {
|
||
return ""
|
||
}
|
||
return c.GetHeader("session_id")
|
||
}
|
||
|
||
// logPrefix 生成统一的日志前缀
|
||
func logPrefix(sessionID, accountName string) string {
|
||
if sessionID != "" {
|
||
return fmt.Sprintf("[antigravity-Forward] session=%s account=%s", sessionID, accountName)
|
||
}
|
||
return fmt.Sprintf("[antigravity-Forward] account=%s", accountName)
|
||
}
|
||
|
||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||
type AntigravityGatewayService struct {
|
||
accountRepo AccountRepository
|
||
tokenProvider *AntigravityTokenProvider
|
||
rateLimitService *RateLimitService
|
||
httpUpstream HTTPUpstream
|
||
settingService *SettingService
|
||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||
schedulerSnapshot *SchedulerSnapshotService
|
||
}
|
||
|
||
func NewAntigravityGatewayService(
|
||
accountRepo AccountRepository,
|
||
cache GatewayCache,
|
||
schedulerSnapshot *SchedulerSnapshotService,
|
||
tokenProvider *AntigravityTokenProvider,
|
||
rateLimitService *RateLimitService,
|
||
httpUpstream HTTPUpstream,
|
||
settingService *SettingService,
|
||
) *AntigravityGatewayService {
|
||
return &AntigravityGatewayService{
|
||
accountRepo: accountRepo,
|
||
tokenProvider: tokenProvider,
|
||
rateLimitService: rateLimitService,
|
||
httpUpstream: httpUpstream,
|
||
settingService: settingService,
|
||
cache: cache,
|
||
schedulerSnapshot: schedulerSnapshot,
|
||
}
|
||
}
|
||
|
||
// GetTokenProvider 返回 token provider
|
||
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
|
||
return s.tokenProvider
|
||
}
|
||
|
||
// getLogConfig 获取上游错误日志配置
|
||
// 返回是否记录日志体和最大字节数
|
||
func (s *AntigravityGatewayService) getLogConfig() (logBody bool, maxBytes int) {
|
||
maxBytes = 2048 // 默认值
|
||
if s.settingService == nil || s.settingService.cfg == nil {
|
||
return false, maxBytes
|
||
}
|
||
cfg := s.settingService.cfg.Gateway
|
||
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||
}
|
||
return cfg.LogUpstreamErrorBody, maxBytes
|
||
}
|
||
|
||
// getUpstreamErrorDetail 获取上游错误详情(用于日志记录)
|
||
func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
|
||
logBody, maxBytes := s.getLogConfig()
|
||
if !logBody {
|
||
return ""
|
||
}
|
||
return truncateString(string(body), maxBytes)
|
||
}
|
||
|
||
// checkErrorPolicy nil 安全的包装
|
||
func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult {
|
||
if s.rateLimitService == nil {
|
||
return ErrorPolicyNone
|
||
}
|
||
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
||
}
|
||
|
||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
|
||
// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。
|
||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
|
||
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
||
case ErrorPolicySkipped:
|
||
return true, http.StatusInternalServerError, nil
|
||
case ErrorPolicyMatched:
|
||
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
||
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||
return true, statusCode, nil
|
||
case ErrorPolicyTempUnscheduled:
|
||
slog.Info("temp_unschedulable_matched",
|
||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||
}
|
||
return false, statusCode, nil
|
||
}
|
||
|
||
// mapAntigravityModel 获取映射后的模型名
|
||
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping)
|
||
// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号
|
||
func mapAntigravityModel(account *Account, requestedModel string) string {
|
||
if account == nil {
|
||
return ""
|
||
}
|
||
|
||
// 获取映射表(未配置时自动使用 DefaultAntigravityModelMapping)
|
||
mapping := account.GetModelMapping()
|
||
if len(mapping) == 0 {
|
||
return "" // 无映射配置(非 Antigravity 平台)
|
||
}
|
||
|
||
// 通过映射表查询(支持精确匹配 + 通配符)
|
||
mapped := account.GetMappedModel(requestedModel)
|
||
|
||
// 判断是否映射成功(mapped != requestedModel 说明找到了映射规则)
|
||
if mapped != requestedModel {
|
||
return mapped
|
||
}
|
||
|
||
// 如果 mapped == requestedModel,检查是否在映射表中配置(精确或通配符)
|
||
// 这区分两种情况:
|
||
// 1. 映射表中有 "model-a": "model-a"(显式透传)→ 返回 model-a
|
||
// 2. 通配符匹配 "claude-*": "claude-sonnet-4-5" 恰好目标等于请求名 → 返回 model-a
|
||
// 3. 映射表中没有 model-a 的配置 → 返回空(不支持)
|
||
if account.IsModelSupported(requestedModel) {
|
||
return requestedModel
|
||
}
|
||
|
||
// 未在映射表中配置的模型,返回空字符串(不支持)
|
||
return ""
|
||
}
|
||
|
||
// getMappedModel 获取映射后的模型名
|
||
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底
|
||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||
return mapAntigravityModel(account, requestedModel)
|
||
}
|
||
|
||
// applyThinkingModelSuffix 根据 thinking 配置调整模型名
|
||
// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking
|
||
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
|
||
if !thinkingEnabled {
|
||
return mappedModel
|
||
}
|
||
if mappedModel == "claude-sonnet-4-5" {
|
||
return "claude-sonnet-4-5-thinking"
|
||
}
|
||
return mappedModel
|
||
}
|
||
|
||
// IsModelSupported 检查模型是否被支持
|
||
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||
return strings.HasPrefix(requestedModel, "claude-") ||
|
||
strings.HasPrefix(requestedModel, "gemini-")
|
||
}
|
||
|
||
// TestConnectionResult 测试连接结果
|
||
type TestConnectionResult struct {
|
||
Text string // 响应文本
|
||
MappedModel string // 实际使用的模型
|
||
}
|
||
|
||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||
|
||
// 获取 token
|
||
if s.tokenProvider == nil {
|
||
return nil, errors.New("antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 模型映射
|
||
mappedModel := s.getMappedModel(account, modelID)
|
||
if mappedModel == "" {
|
||
return nil, fmt.Errorf("model %s not in whitelist", modelID)
|
||
}
|
||
|
||
// 构建请求体
|
||
var requestBody []byte
|
||
if strings.HasPrefix(modelID, "gemini-") {
|
||
// Gemini 模型:直接使用 Gemini 格式
|
||
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
||
} else {
|
||
// Claude 模型:使用协议转换
|
||
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||
}
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// URL fallback 循环
|
||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||
if len(availableURLs) == 0 {
|
||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||
}
|
||
|
||
var lastErr error
|
||
for urlIdx, baseURL := range availableURLs {
|
||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
// 调试日志:Test 请求信息
|
||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||
|
||
// 发送请求
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||
continue
|
||
}
|
||
return nil, lastErr
|
||
}
|
||
|
||
// 读取响应
|
||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
// 检查是否需要 URL 降级
|
||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||
continue
|
||
}
|
||
|
||
if resp.StatusCode >= 400 {
|
||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
// 解析流式响应,提取文本
|
||
text := extractTextFromSSEResponse(respBody)
|
||
|
||
// 标记成功的 URL,下次优先使用
|
||
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
|
||
return &TestConnectionResult{
|
||
Text: text,
|
||
MappedModel: mappedModel,
|
||
}, nil
|
||
}
|
||
|
||
return nil, lastErr
|
||
}
|
||
|
||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
|
||
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
|
||
payload := map[string]any{
|
||
"contents": []map[string]any{
|
||
{
|
||
"role": "user",
|
||
"parts": []map[string]any{
|
||
{"text": "."},
|
||
},
|
||
},
|
||
},
|
||
// Antigravity 上游要求必须包含身份提示词
|
||
"systemInstruction": map[string]any{
|
||
"parts": []map[string]any{
|
||
{"text": antigravity.GetDefaultIdentityPatch()},
|
||
},
|
||
},
|
||
"generationConfig": map[string]any{
|
||
"maxOutputTokens": 1,
|
||
},
|
||
}
|
||
payloadBytes, _ := json.Marshal(payload)
|
||
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
|
||
}
|
||
|
||
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
|
||
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
|
||
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
|
||
claudeReq := &antigravity.ClaudeRequest{
|
||
Model: mappedModel,
|
||
Messages: []antigravity.ClaudeMessage{
|
||
{
|
||
Role: "user",
|
||
Content: json.RawMessage(`"."`),
|
||
},
|
||
},
|
||
MaxTokens: 1,
|
||
Stream: false,
|
||
}
|
||
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
|
||
opts := antigravity.DefaultTransformOptions()
|
||
if s.settingService == nil {
|
||
return opts
|
||
}
|
||
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
||
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
||
return opts
|
||
}
|
||
|
||
// extractTextFromSSEResponse 从 SSE 流式响应中提取文本
|
||
func extractTextFromSSEResponse(respBody []byte) string {
|
||
var texts []string
|
||
lines := bytes.Split(respBody, []byte("\n"))
|
||
|
||
for _, line := range lines {
|
||
line = bytes.TrimSpace(line)
|
||
if len(line) == 0 {
|
||
continue
|
||
}
|
||
|
||
// 跳过 SSE 前缀
|
||
if bytes.HasPrefix(line, []byte("data:")) {
|
||
line = bytes.TrimPrefix(line, []byte("data:"))
|
||
line = bytes.TrimSpace(line)
|
||
}
|
||
|
||
// 跳过非 JSON 行
|
||
if len(line) == 0 || line[0] != '{' {
|
||
continue
|
||
}
|
||
|
||
// 解析 JSON
|
||
var data map[string]any
|
||
if err := json.Unmarshal(line, &data); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 尝试从 response.candidates[0].content.parts[].text 提取
|
||
response, ok := data["response"].(map[string]any)
|
||
if !ok {
|
||
// 尝试直接从 candidates 提取(某些响应格式)
|
||
response = data
|
||
}
|
||
|
||
candidates, ok := response["candidates"].([]any)
|
||
if !ok || len(candidates) == 0 {
|
||
continue
|
||
}
|
||
|
||
candidate, ok := candidates[0].(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
content, ok := candidate["content"].(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
parts, ok := content["parts"].([]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
for _, part := range parts {
|
||
if partMap, ok := part.(map[string]any); ok {
|
||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||
texts = append(texts, text)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return strings.Join(texts, "")
|
||
}
|
||
|
||
// injectIdentityPatchToGeminiRequest 为 Gemini 格式请求注入身份提示词
|
||
// 如果请求中已包含 "You are Antigravity" 则不重复注入
|
||
func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
|
||
var request map[string]any
|
||
if err := json.Unmarshal(body, &request); err != nil {
|
||
return nil, fmt.Errorf("解析 Gemini 请求失败: %w", err)
|
||
}
|
||
|
||
// 检查现有 systemInstruction 是否已包含身份提示词
|
||
if sysInst, ok := request["systemInstruction"].(map[string]any); ok {
|
||
if parts, ok := sysInst["parts"].([]any); ok {
|
||
for _, part := range parts {
|
||
if partMap, ok := part.(map[string]any); ok {
|
||
if text, ok := partMap["text"].(string); ok {
|
||
if strings.Contains(text, "You are Antigravity") {
|
||
// 已包含身份提示词,直接返回原始请求
|
||
return body, nil
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取默认身份提示词
|
||
identityPatch := antigravity.GetDefaultIdentityPatch()
|
||
|
||
// 构建新的 systemInstruction
|
||
newPart := map[string]any{"text": identityPatch}
|
||
|
||
if existing, ok := request["systemInstruction"].(map[string]any); ok {
|
||
// 已有 systemInstruction,在开头插入身份提示词
|
||
if parts, ok := existing["parts"].([]any); ok {
|
||
existing["parts"] = append([]any{newPart}, parts...)
|
||
} else {
|
||
existing["parts"] = []any{newPart}
|
||
}
|
||
} else {
|
||
// 没有 systemInstruction,创建新的
|
||
request["systemInstruction"] = map[string]any{
|
||
"parts": []any{newPart},
|
||
}
|
||
}
|
||
|
||
return json.Marshal(request)
|
||
}
|
||
|
||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||
var request any
|
||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||
}
|
||
|
||
wrapped := map[string]any{
|
||
"project": projectID,
|
||
"requestId": "agent-" + uuid.New().String(),
|
||
"userAgent": "antigravity", // 固定值,与官方客户端一致
|
||
"requestType": "agent",
|
||
"model": model,
|
||
"request": request,
|
||
}
|
||
|
||
return json.Marshal(wrapped)
|
||
}
|
||
|
||
// unwrapV1InternalResponse 解包 v1internal 响应
|
||
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
||
var outer map[string]any
|
||
if err := json.Unmarshal(body, &outer); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp, ok := outer["response"]; ok {
|
||
return json.Marshal(resp)
|
||
}
|
||
|
||
return body, nil
|
||
}
|
||
|
||
// isModelNotFoundError 检测是否为模型不存在的 404 错误
|
||
func isModelNotFoundError(statusCode int, body []byte) bool {
|
||
if statusCode != 404 {
|
||
return false
|
||
}
|
||
|
||
bodyStr := strings.ToLower(string(body))
|
||
keywords := []string{"model not found", "unknown model", "not found"}
|
||
for _, keyword := range keywords {
|
||
if strings.Contains(bodyStr, keyword) {
|
||
return true
|
||
}
|
||
}
|
||
return true // 404 without specific message also treated as model not found
|
||
}
|
||
|
||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||
//
|
||
// 限流处理流程:
|
||
//
|
||
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
|
||
// ├─ 成功 → 正常返回
|
||
// └─ 429/503 → handleSmartRetry
|
||
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
|
||
// └─ retryDelay < 7s → 等待后重试 1 次
|
||
// ├─ 成功 → 正常返回
|
||
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
|
||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
// 上游透传账号直接转发,不走 OAuth token 刷新
|
||
if account.Type == AccountTypeUpstream {
|
||
return s.ForwardUpstream(ctx, c, account, body)
|
||
}
|
||
|
||
startTime := time.Now()
|
||
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
// 解析 Claude 请求
|
||
var claudeReq antigravity.ClaudeRequest
|
||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
|
||
}
|
||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
|
||
}
|
||
|
||
originalModel := claudeReq.Model
|
||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||
if mappedModel == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||
}
|
||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||
|
||
// 获取 access_token
|
||
if s.tokenProvider == nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token")
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 获取转换选项
|
||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||
transformOpts := s.getClaudeTransformOptions(ctx)
|
||
transformOpts.EnableIdentityPatch = true // 强制启用,Antigravity 上游必需
|
||
|
||
// 转换 Claude 请求为 Gemini 格式
|
||
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
|
||
}
|
||
|
||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||
action := "streamGenerateContent"
|
||
|
||
// 执行带重试的请求
|
||
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: action,
|
||
body: geminiBody,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession, // Forward 由上层判断粘性会话
|
||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if err != nil {
|
||
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
|
||
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: http.StatusServiceUnavailable,
|
||
ForceCacheBilling: switchErr.IsStickySession,
|
||
}
|
||
}
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||
}
|
||
resp := result.resp
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
|
||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "signature_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
|
||
// Conservative two-stage fallback:
|
||
// 1) Disable top-level thinking + thinking->text
|
||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||
|
||
retryStages := []struct {
|
||
name string
|
||
strip func(*antigravity.ClaudeRequest) (bool, error)
|
||
}{
|
||
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
|
||
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
|
||
}
|
||
|
||
for _, stage := range retryStages {
|
||
retryClaudeReq := claudeReq
|
||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||
|
||
stripped, stripErr := stage.strip(&retryClaudeReq)
|
||
if stripErr != nil || !stripped {
|
||
continue
|
||
}
|
||
|
||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||
|
||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||
if txErr != nil {
|
||
continue
|
||
}
|
||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: action,
|
||
body: retryGeminiBody,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession,
|
||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if retryErr != nil {
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: 0,
|
||
Kind: "signature_retry_request_error",
|
||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||
})
|
||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||
continue
|
||
}
|
||
|
||
retryResp := retryResult.resp
|
||
if retryResp.StatusCode < 400 {
|
||
_ = resp.Body.Close()
|
||
resp = retryResp
|
||
respBody = nil
|
||
break
|
||
}
|
||
|
||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||
_ = retryResp.Body.Close()
|
||
if retryResp.StatusCode == http.StatusTooManyRequests {
|
||
retryBaseURL := ""
|
||
if retryResp.Request != nil && retryResp.Request.URL != nil {
|
||
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
|
||
}
|
||
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
|
||
}
|
||
kind := "signature_retry"
|
||
if strings.TrimSpace(stage.name) != "" {
|
||
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
|
||
}
|
||
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
|
||
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
|
||
retryUpstreamDetail := ""
|
||
if logBody {
|
||
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: retryResp.StatusCode,
|
||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||
Kind: kind,
|
||
Message: retryUpstreamMsg,
|
||
Detail: retryUpstreamDetail,
|
||
})
|
||
|
||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||
respBody = retryBody
|
||
resp = &http.Response{
|
||
StatusCode: retryResp.StatusCode,
|
||
Header: retryResp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||
}
|
||
break
|
||
}
|
||
|
||
// Still signature-related; capture context and allow next stage.
|
||
respBody = retryBody
|
||
resp = &http.Response{
|
||
StatusCode: retryResp.StatusCode,
|
||
Header: retryResp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||
}
|
||
}
|
||
}
|
||
|
||
// 处理错误响应(重试后仍失败或不触发重试)
|
||
if resp.StatusCode >= 400 {
|
||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
if logBody {
|
||
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "prompt_too_long",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &PromptTooLongError{
|
||
StatusCode: resp.StatusCode,
|
||
RequestID: resp.Header.Get("x-request-id"),
|
||
Body: respBody,
|
||
}
|
||
}
|
||
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||
}
|
||
|
||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||
}
|
||
}
|
||
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
var clientDisconnect bool
|
||
if claudeReq.Stream {
|
||
// 客户端要求流式,直接透传转换
|
||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
clientDisconnect = streamRes.clientDisconnect
|
||
} else {
|
||
// 客户端要求非流式,收集流式响应后转换返回
|
||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel, // 使用原始模型用于计费和日志
|
||
Stream: claudeReq.Stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
ClientDisconnect: clientDisconnect,
|
||
}, nil
|
||
}
|
||
|
||
func isSignatureRelatedError(respBody []byte) bool {
|
||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||
if msg == "" {
|
||
// Fallback: best-effort scan of the raw payload.
|
||
msg = strings.ToLower(string(respBody))
|
||
}
|
||
|
||
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
|
||
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||
return true
|
||
}
|
||
|
||
// Also detect thinking block structural errors:
|
||
// "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// isPromptTooLongError 检测是否为 prompt too long 错误
|
||
func isPromptTooLongError(respBody []byte) bool {
|
||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||
if msg == "" {
|
||
msg = strings.ToLower(string(respBody))
|
||
}
|
||
return strings.Contains(msg, "prompt is too long") ||
|
||
strings.Contains(msg, "request is too long") ||
|
||
strings.Contains(msg, "context length exceeded") ||
|
||
strings.Contains(msg, "max_tokens")
|
||
}
|
||
|
||
// isPassthroughErrorMessage 检查错误消息是否在透传白名单中
|
||
func isPassthroughErrorMessage(msg string) bool {
|
||
lower := strings.ToLower(msg)
|
||
for _, pattern := range antigravityPassthroughErrorMessages {
|
||
if strings.Contains(lower, pattern) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// getPassthroughOrDefault 若消息在白名单内则返回原始消息,否则返回默认消息
|
||
func getPassthroughOrDefault(upstreamMsg, defaultMsg string) string {
|
||
if isPassthroughErrorMessage(upstreamMsg) {
|
||
return upstreamMsg
|
||
}
|
||
return defaultMsg
|
||
}
|
||
|
||
func extractAntigravityErrorMessage(body []byte) string {
|
||
var payload map[string]any
|
||
if err := json.Unmarshal(body, &payload); err != nil {
|
||
return ""
|
||
}
|
||
|
||
// Google-style: {"error": {"message": "..."}}
|
||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||
return msg
|
||
}
|
||
}
|
||
|
||
// Fallback: top-level message
|
||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||
return msg
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||
// This preserves the thinking content while avoiding signature validation errors.
|
||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
|
||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||
if req == nil {
|
||
return false, nil
|
||
}
|
||
|
||
changed := false
|
||
if req.Thinking != nil {
|
||
req.Thinking = nil
|
||
changed = true
|
||
}
|
||
|
||
for i := range req.Messages {
|
||
raw := req.Messages[i].Content
|
||
if len(raw) == 0 {
|
||
continue
|
||
}
|
||
|
||
// If content is a string, nothing to strip.
|
||
var str string
|
||
if json.Unmarshal(raw, &str) == nil {
|
||
continue
|
||
}
|
||
|
||
// Otherwise treat as an array of blocks and convert thinking blocks to text.
|
||
var blocks []map[string]any
|
||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||
continue
|
||
}
|
||
|
||
filtered := make([]map[string]any, 0, len(blocks))
|
||
modifiedAny := false
|
||
for _, block := range blocks {
|
||
t, _ := block["type"].(string)
|
||
switch t {
|
||
case "thinking":
|
||
thinkingText, _ := block["thinking"].(string)
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
case "redacted_thinking":
|
||
modifiedAny = true
|
||
case "":
|
||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
} else {
|
||
filtered = append(filtered, block)
|
||
}
|
||
default:
|
||
filtered = append(filtered, block)
|
||
}
|
||
}
|
||
|
||
if !modifiedAny {
|
||
continue
|
||
}
|
||
|
||
if len(filtered) == 0 {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": "(content removed)",
|
||
})
|
||
}
|
||
|
||
newRaw, err := json.Marshal(filtered)
|
||
if err != nil {
|
||
return changed, err
|
||
}
|
||
req.Messages[i].Content = newRaw
|
||
changed = true
|
||
}
|
||
|
||
return changed, nil
|
||
}
|
||
|
||
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
|
||
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
|
||
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||
if req == nil {
|
||
return false, nil
|
||
}
|
||
|
||
changed := false
|
||
if req.Thinking != nil {
|
||
req.Thinking = nil
|
||
changed = true
|
||
}
|
||
|
||
for i := range req.Messages {
|
||
raw := req.Messages[i].Content
|
||
if len(raw) == 0 {
|
||
continue
|
||
}
|
||
|
||
// If content is a string, nothing to strip.
|
||
var str string
|
||
if json.Unmarshal(raw, &str) == nil {
|
||
continue
|
||
}
|
||
|
||
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
|
||
var blocks []map[string]any
|
||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||
continue
|
||
}
|
||
|
||
filtered := make([]map[string]any, 0, len(blocks))
|
||
modifiedAny := false
|
||
for _, block := range blocks {
|
||
t, _ := block["type"].(string)
|
||
switch t {
|
||
case "thinking":
|
||
// Convert thinking to text, skip if empty
|
||
thinkingText, _ := block["thinking"].(string)
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
case "redacted_thinking":
|
||
// Remove redacted_thinking (cannot convert encrypted content)
|
||
modifiedAny = true
|
||
case "tool_use":
|
||
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
|
||
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
|
||
name, _ := block["name"].(string)
|
||
id, _ := block["id"].(string)
|
||
input := block["input"]
|
||
inputJSON, _ := json.Marshal(input)
|
||
text := "(tool_use)"
|
||
if name != "" {
|
||
text += " name=" + name
|
||
}
|
||
if id != "" {
|
||
text += " id=" + id
|
||
}
|
||
if len(inputJSON) > 0 && string(inputJSON) != "null" {
|
||
text += " input=" + string(inputJSON)
|
||
}
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": text,
|
||
})
|
||
modifiedAny = true
|
||
case "tool_result":
|
||
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
|
||
toolUseID, _ := block["tool_use_id"].(string)
|
||
isError, _ := block["is_error"].(bool)
|
||
content := block["content"]
|
||
contentJSON, _ := json.Marshal(content)
|
||
text := "(tool_result)"
|
||
if toolUseID != "" {
|
||
text += " tool_use_id=" + toolUseID
|
||
}
|
||
if isError {
|
||
text += " is_error=true"
|
||
}
|
||
if len(contentJSON) > 0 && string(contentJSON) != "null" {
|
||
text += "\n" + string(contentJSON)
|
||
}
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": text,
|
||
})
|
||
modifiedAny = true
|
||
case "":
|
||
// Handle untyped block with "thinking" field
|
||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
} else {
|
||
filtered = append(filtered, block)
|
||
}
|
||
default:
|
||
filtered = append(filtered, block)
|
||
}
|
||
}
|
||
|
||
if !modifiedAny {
|
||
continue
|
||
}
|
||
|
||
if len(filtered) == 0 {
|
||
// Keep request valid: upstream rejects empty content arrays.
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": "(content removed)",
|
||
})
|
||
}
|
||
|
||
newRaw, err := json.Marshal(filtered)
|
||
if err != nil {
|
||
return changed, err
|
||
}
|
||
req.Messages[i].Content = newRaw
|
||
changed = true
|
||
}
|
||
|
||
return changed, nil
|
||
}
|
||
|
||
// ForwardGemini 转发 Gemini 协议请求
|
||
//
|
||
// 限流处理流程:
|
||
//
|
||
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
|
||
// ├─ 成功 → 正常返回
|
||
// └─ 429/503 → handleSmartRetry
|
||
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
|
||
// └─ retryDelay < 7s → 等待后重试 1 次
|
||
// ├─ 成功 → 正常返回
|
||
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
|
||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
if strings.TrimSpace(originalModel) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||
}
|
||
if strings.TrimSpace(action) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||
}
|
||
if len(body) == 0 {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||
}
|
||
|
||
// 解析请求以获取 image_size(用于图片计费)
|
||
imageSize := s.extractImageSize(body)
|
||
|
||
switch action {
|
||
case "generateContent", "streamGenerateContent":
|
||
// ok
|
||
case "countTokens":
|
||
// 直接返回空值,不透传上游
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
|
||
return &ForwardResult{
|
||
RequestID: "",
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(time.Now()),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
default:
|
||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||
}
|
||
|
||
mappedModel := s.getMappedModel(account, originalModel)
|
||
if mappedModel == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||
}
|
||
|
||
// 获取 access_token
|
||
if s.tokenProvider == nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token")
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// Antigravity 上游要求必须包含身份提示词,注入到请求中
|
||
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Invalid request body")
|
||
}
|
||
|
||
// 清理 Schema
|
||
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||
injectedBody = cleanedBody
|
||
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||
} else {
|
||
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||
}
|
||
|
||
// 包装请求
|
||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
|
||
}
|
||
|
||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
|
||
upstreamAction := "streamGenerateContent"
|
||
|
||
// 执行带重试的请求
|
||
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: upstreamAction,
|
||
body: wrappedBody,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession, // ForwardGemini 由上层判断粘性会话
|
||
groupID: 0, // ForwardGemini 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // ForwardGemini 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if err != nil {
|
||
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
|
||
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: http.StatusServiceUnavailable,
|
||
ForceCacheBilling: switchErr.IsStickySession,
|
||
}
|
||
}
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||
}
|
||
resp := result.resp
|
||
defer func() {
|
||
if resp != nil && resp.Body != nil {
|
||
_ = resp.Body.Close()
|
||
}
|
||
}()
|
||
|
||
// 处理错误响应
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
contentType := resp.Header.Get("Content-Type")
|
||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||
_ = resp.Body.Close()
|
||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||
|
||
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
|
||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||
|
||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||
if err == nil {
|
||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||
if err == nil {
|
||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||
if err == nil && fallbackResp.StatusCode < 400 {
|
||
_ = resp.Body.Close()
|
||
resp = fallbackResp
|
||
} else if fallbackResp != nil {
|
||
_ = fallbackResp.Body.Close()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// fallback 成功:继续按正常响应处理
|
||
if resp.StatusCode < 400 {
|
||
goto handleSuccess
|
||
}
|
||
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
|
||
unwrappedForOps := unwrapped
|
||
if unwrapErr != nil || len(unwrappedForOps) == 0 {
|
||
unwrappedForOps = respBody
|
||
}
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps)
|
||
|
||
// Always record upstream context for Ops error logs, even when we will failover.
|
||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
|
||
}
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "http_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
|
||
c.Data(resp.StatusCode, contentType, unwrappedForOps)
|
||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||
}
|
||
|
||
handleSuccess:
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
var clientDisconnect bool
|
||
|
||
if stream {
|
||
// 客户端要求流式,直接透传
|
||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
clientDisconnect = streamRes.clientDisconnect
|
||
} else {
|
||
// 客户端要求非流式,收集流式响应后返回
|
||
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
}
|
||
|
||
if usage == nil {
|
||
usage = &ClaudeUsage{}
|
||
}
|
||
|
||
// 判断是否为图片生成模型
|
||
imageCount := 0
|
||
if isImageGenerationModel(mappedModel) {
|
||
// Gemini 图片生成 API 每次请求只生成一张图片(API 限制)
|
||
imageCount = 1
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel,
|
||
Stream: stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
ClientDisconnect: clientDisconnect,
|
||
ImageCount: imageCount,
|
||
ImageSize: imageSize,
|
||
}, nil
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||
switch statusCode {
|
||
case 401, 403, 429, 529:
|
||
return true
|
||
default:
|
||
return statusCode >= 500
|
||
}
|
||
}
|
||
|
||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||
delay := antigravityRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||
if delay > antigravityRetryMaxDelay {
|
||
delay = antigravityRetryMaxDelay
|
||
}
|
||
|
||
// +/- 20% jitter
|
||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
|
||
sleepFor := delay + jitter
|
||
if sleepFor < 0 {
|
||
sleepFor = 0
|
||
}
|
||
|
||
select {
|
||
case <-ctx.Done():
|
||
return false
|
||
case <-time.After(sleepFor):
|
||
return true
|
||
}
|
||
}
|
||
|
||
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
|
||
// 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key
|
||
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false)
|
||
func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, accountID int64, modelName, prefix string, statusCode int, resetAt time.Time, afterSmartRetry bool) bool {
|
||
if repo == nil || modelName == "" {
|
||
return false
|
||
}
|
||
// 直接使用官方模型 ID 作为 key,不再转换为 scope
|
||
if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil {
|
||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||
return false
|
||
}
|
||
if afterSmartRetry {
|
||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||
} else {
|
||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||
}
|
||
return true
|
||
}
|
||
|
||
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
|
||
if raw == "" {
|
||
return 0, false
|
||
}
|
||
seconds, err := strconv.Atoi(raw)
|
||
if err != nil || seconds <= 0 {
|
||
return 0, false
|
||
}
|
||
return time.Duration(seconds) * time.Second, true
|
||
}
|
||
|
||
// antigravitySmartRetryInfo 智能重试所需的信息
|
||
type antigravitySmartRetryInfo struct {
|
||
RetryDelay time.Duration // 重试延迟时间
|
||
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5")
|
||
}
|
||
|
||
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
|
||
// 返回解析结果,如果解析失败或不满足条件返回 nil
|
||
//
|
||
// 支持两种情况:
|
||
// 1. 429 RESOURCE_EXHAUSTED + RATE_LIMIT_EXCEEDED:
|
||
// - error.status == "RESOURCE_EXHAUSTED"
|
||
// - error.details[].reason == "RATE_LIMIT_EXCEEDED"
|
||
//
|
||
// 2. 503 UNAVAILABLE + MODEL_CAPACITY_EXHAUSTED:
|
||
// - error.status == "UNAVAILABLE"
|
||
// - error.details[].reason == "MODEL_CAPACITY_EXHAUSTED"
|
||
//
|
||
// 必须满足以下条件才会返回有效值:
|
||
// - error.details[] 中存在 @type == "type.googleapis.com/google.rpc.RetryInfo" 的元素
|
||
// - 该元素包含 retryDelay 字段,格式为 "数字s"(如 "0.201506475s")
|
||
func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||
return nil
|
||
}
|
||
|
||
errObj, ok := parsed["error"].(map[string]any)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
// 检查 status 是否符合条件
|
||
// 情况1: 429 RESOURCE_EXHAUSTED (需要进一步检查 reason == RATE_LIMIT_EXCEEDED)
|
||
// 情况2: 503 UNAVAILABLE (需要进一步检查 reason == MODEL_CAPACITY_EXHAUSTED)
|
||
status, _ := errObj["status"].(string)
|
||
isResourceExhausted := status == googleRPCStatusResourceExhausted
|
||
isUnavailable := status == googleRPCStatusUnavailable
|
||
|
||
if !isResourceExhausted && !isUnavailable {
|
||
return nil
|
||
}
|
||
|
||
details, ok := errObj["details"].([]any)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
var retryDelay time.Duration
|
||
var modelName string
|
||
var hasRateLimitExceeded bool // 429 需要此 reason
|
||
var hasModelCapacityExhausted bool // 503 需要此 reason
|
||
|
||
for _, d := range details {
|
||
dm, ok := d.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
atType, _ := dm["@type"].(string)
|
||
|
||
// 从 ErrorInfo 提取模型名称和 reason
|
||
if atType == googleRPCTypeErrorInfo {
|
||
if meta, ok := dm["metadata"].(map[string]any); ok {
|
||
if model, ok := meta["model"].(string); ok {
|
||
modelName = model
|
||
}
|
||
}
|
||
// 检查 reason
|
||
if reason, ok := dm["reason"].(string); ok {
|
||
if reason == googleRPCReasonModelCapacityExhausted {
|
||
hasModelCapacityExhausted = true
|
||
}
|
||
if reason == googleRPCReasonRateLimitExceeded {
|
||
hasRateLimitExceeded = true
|
||
}
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 从 RetryInfo 提取重试延迟
|
||
if atType == googleRPCTypeRetryInfo {
|
||
delay, ok := dm["retryDelay"].(string)
|
||
if !ok || delay == "" {
|
||
continue
|
||
}
|
||
// 使用 time.ParseDuration 解析,支持所有 Go duration 格式
|
||
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
|
||
dur, err := time.ParseDuration(delay)
|
||
if err != nil {
|
||
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
|
||
continue
|
||
}
|
||
retryDelay = dur
|
||
}
|
||
}
|
||
|
||
// 验证条件
|
||
// 情况1: RESOURCE_EXHAUSTED 需要有 RATE_LIMIT_EXCEEDED reason
|
||
// 情况2: UNAVAILABLE 需要有 MODEL_CAPACITY_EXHAUSTED reason
|
||
if isResourceExhausted && !hasRateLimitExceeded {
|
||
return nil
|
||
}
|
||
if isUnavailable && !hasModelCapacityExhausted {
|
||
return nil
|
||
}
|
||
|
||
// 必须有模型名才返回有效结果
|
||
if modelName == "" {
|
||
return nil
|
||
}
|
||
|
||
// 如果上游未提供 retryDelay,使用默认限流时间
|
||
if retryDelay <= 0 {
|
||
retryDelay = antigravityDefaultRateLimitDuration
|
||
}
|
||
|
||
return &antigravitySmartRetryInfo{
|
||
RetryDelay: retryDelay,
|
||
ModelName: modelName,
|
||
}
|
||
}
|
||
|
||
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
|
||
// 返回:
|
||
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold)
|
||
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold)
|
||
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0)
|
||
// - modelName: 限流的模型名称
|
||
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
|
||
if account.Platform != PlatformAntigravity {
|
||
return false, false, 0, ""
|
||
}
|
||
|
||
info := parseAntigravitySmartRetryInfo(respBody)
|
||
if info == nil {
|
||
return false, false, 0, ""
|
||
}
|
||
|
||
// retryDelay >= 阈值:直接限流模型,不重试
|
||
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
|
||
if info.RetryDelay >= antigravityRateLimitThreshold {
|
||
return false, true, info.RetryDelay, info.ModelName
|
||
}
|
||
|
||
// retryDelay < 阈值:智能重试
|
||
waitDuration = info.RetryDelay
|
||
if waitDuration < antigravitySmartRetryMinWait {
|
||
waitDuration = antigravitySmartRetryMinWait
|
||
}
|
||
|
||
return true, false, waitDuration, info.ModelName
|
||
}
|
||
|
||
// handleModelRateLimitParams 模型级限流处理参数
|
||
type handleModelRateLimitParams struct {
|
||
ctx context.Context
|
||
prefix string
|
||
account *Account
|
||
statusCode int
|
||
body []byte
|
||
cache GatewayCache
|
||
groupID int64
|
||
sessionHash string
|
||
isStickySession bool
|
||
}
|
||
|
||
// handleModelRateLimitResult 模型级限流处理结果
|
||
type handleModelRateLimitResult struct {
|
||
Handled bool // 是否已处理
|
||
ShouldRetry bool // 是否等待后重试
|
||
WaitDuration time.Duration // 等待时间
|
||
SwitchError *AntigravityAccountSwitchError // 账号切换错误
|
||
}
|
||
|
||
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
|
||
// 仅处理 429/503,解析模型名和 retryDelay
|
||
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试
|
||
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
||
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
|
||
if p.statusCode != 429 && p.statusCode != 503 {
|
||
return &handleModelRateLimitResult{Handled: false}
|
||
}
|
||
|
||
info := parseAntigravitySmartRetryInfo(p.body)
|
||
if info == nil || info.ModelName == "" {
|
||
return &handleModelRateLimitResult{Handled: false}
|
||
}
|
||
|
||
// < antigravityRateLimitThreshold: 等待后重试
|
||
if info.RetryDelay < antigravityRateLimitThreshold {
|
||
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
||
return &handleModelRateLimitResult{
|
||
Handled: true,
|
||
ShouldRetry: true,
|
||
WaitDuration: info.RetryDelay,
|
||
}
|
||
}
|
||
|
||
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
||
s.setModelRateLimitAndClearSession(p, info)
|
||
|
||
return &handleModelRateLimitResult{
|
||
Handled: true,
|
||
SwitchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: info.ModelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话
|
||
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
|
||
resetAt := time.Now().Add(info.RetryDelay)
|
||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
|
||
p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay)
|
||
|
||
// 设置模型限流状态(数据库)
|
||
if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil {
|
||
log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
|
||
}
|
||
|
||
// 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, info.ModelName, resetAt)
|
||
|
||
// 清除粘性会话绑定
|
||
if p.cache != nil && p.sessionHash != "" {
|
||
_ = p.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
|
||
}
|
||
}
|
||
|
||
// updateAccountModelRateLimitInCache 立即更新 Redis 中账号的模型限流状态
|
||
func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx context.Context, account *Account, modelKey string, resetAt time.Time) {
|
||
if s.schedulerSnapshot == nil || account == nil || modelKey == "" {
|
||
return
|
||
}
|
||
|
||
// 更新账号对象的 Extra 字段
|
||
if account.Extra == nil {
|
||
account.Extra = make(map[string]any)
|
||
}
|
||
|
||
limits, _ := account.Extra["model_rate_limits"].(map[string]any)
|
||
if limits == nil {
|
||
limits = make(map[string]any)
|
||
account.Extra["model_rate_limits"] = limits
|
||
}
|
||
|
||
limits[modelKey] = map[string]any{
|
||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||
}
|
||
|
||
// 更新 Redis 快照
|
||
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
|
||
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
|
||
}
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) handleUpstreamError(
|
||
ctx context.Context, prefix string, account *Account,
|
||
statusCode int, headers http.Header, body []byte,
|
||
requestedModel string,
|
||
groupID int64, sessionHash string, isStickySession bool,
|
||
) *handleModelRateLimitResult {
|
||
// 模型级限流处理(优先)
|
||
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
statusCode: statusCode,
|
||
body: body,
|
||
cache: s.cache,
|
||
groupID: groupID,
|
||
sessionHash: sessionHash,
|
||
isStickySession: isStickySession,
|
||
})
|
||
if result.Handled {
|
||
return result
|
||
}
|
||
|
||
// 503 仅处理模型限流(MODEL_CAPACITY_EXHAUSTED),非模型限流不做额外处理
|
||
// 避免将普通的 503 错误误判为账号问题
|
||
if statusCode == 503 {
|
||
return nil
|
||
}
|
||
|
||
// 429:尝试解析模型级限流,解析失败时兜底为账号级限流
|
||
if statusCode == 429 {
|
||
if logBody, maxBytes := s.getLogConfig(); logBody {
|
||
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
|
||
}
|
||
|
||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||
defaultDur := s.getDefaultRateLimitDuration()
|
||
|
||
// 尝试解析模型 key 并设置模型级限流
|
||
modelKey := resolveAntigravityModelKey(requestedModel)
|
||
if modelKey != "" {
|
||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||
log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
|
||
} else {
|
||
log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
|
||
prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
|
||
s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 无法解析模型 key,兜底为账号级限流
|
||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
|
||
prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
|
||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
|
||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||
}
|
||
return nil
|
||
}
|
||
// 其他错误码继续使用 rateLimitService
|
||
if s.rateLimitService == nil {
|
||
return nil
|
||
}
|
||
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||
if shouldDisable {
|
||
log.Printf("%s status=%d marked_error", prefix, statusCode)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// getDefaultRateLimitDuration 获取默认限流时间
|
||
func (s *AntigravityGatewayService) getDefaultRateLimitDuration() time.Duration {
|
||
defaultDur := antigravityDefaultRateLimitDuration
|
||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
|
||
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
|
||
}
|
||
if override, ok := antigravityFallbackCooldownSeconds(); ok {
|
||
defaultDur = override
|
||
}
|
||
return defaultDur
|
||
}
|
||
|
||
// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点
|
||
func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur time.Duration) time.Time {
|
||
if resetAt != nil {
|
||
return time.Unix(*resetAt, 0)
|
||
}
|
||
return time.Now().Add(defaultDur)
|
||
}
|
||
|
||
type antigravityStreamResult struct {
|
||
usage *ClaudeUsage
|
||
firstTokenMs *int
|
||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||
}
|
||
|
||
// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。
|
||
// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。
|
||
type antigravityClientWriter struct {
|
||
w gin.ResponseWriter
|
||
flusher http.Flusher
|
||
disconnected bool
|
||
prefix string // 日志前缀,标识来源方法
|
||
}
|
||
|
||
func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter {
|
||
return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix}
|
||
}
|
||
|
||
// Write 写入数据到客户端,写入失败时标记断开并返回 false
|
||
func (cw *antigravityClientWriter) Write(p []byte) bool {
|
||
if cw.disconnected {
|
||
return false
|
||
}
|
||
if _, err := cw.w.Write(p); err != nil {
|
||
cw.markDisconnected()
|
||
return false
|
||
}
|
||
cw.flusher.Flush()
|
||
return true
|
||
}
|
||
|
||
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
|
||
func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool {
|
||
if cw.disconnected {
|
||
return false
|
||
}
|
||
if _, err := fmt.Fprintf(cw.w, format, args...); err != nil {
|
||
cw.markDisconnected()
|
||
return false
|
||
}
|
||
cw.flusher.Flush()
|
||
return true
|
||
}
|
||
|
||
func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected }
|
||
|
||
func (cw *antigravityClientWriter) markDisconnected() {
|
||
cw.disconnected = true
|
||
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
|
||
}
|
||
|
||
// handleStreamReadError 处理上游读取错误的通用逻辑。
|
||
// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。
|
||
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
|
||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
|
||
return true, true
|
||
}
|
||
if clientDisconnected {
|
||
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
|
||
return true, true
|
||
}
|
||
return false, false
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||
c.Status(resp.StatusCode)
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "text/event-stream; charset=utf-8"
|
||
}
|
||
c.Header("Content-Type", contentType)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
|
||
|
||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||
errorEventSent := false
|
||
sendErrorEvent := func(reason string) {
|
||
if errorEventSent || cw.Disconnected() {
|
||
return
|
||
}
|
||
errorEventSent = true
|
||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||
flusher.Flush()
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||
}
|
||
if ev.err != nil {
|
||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||
}
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||
sendErrorEvent("response_too_large")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||
}
|
||
sendErrorEvent("stream_read_error")
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
if strings.HasPrefix(trimmed, "data:") {
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
cw.Fprintf("%s\n", line)
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr == nil && inner != nil {
|
||
payload = string(inner)
|
||
}
|
||
|
||
// 解析 usage
|
||
var parsed map[string]any
|
||
if json.Unmarshal(inner, &parsed) == nil {
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
// Check for MALFORMED_FUNCTION_CALL
|
||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||
if cand, ok := candidates[0].(map[string]any); ok {
|
||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||
if content, ok := cand["content"]; ok {
|
||
if b, err := json.Marshal(content); err == nil {
|
||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
cw.Fprintf("data: %s\n\n", payload)
|
||
continue
|
||
}
|
||
|
||
cw.Fprintf("%s\n", line)
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
if cw.Disconnected() {
|
||
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity)")
|
||
sendErrorEvent("stream_timeout")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
|
||
// Gemini 流式响应是增量的,需要累积所有 chunk 的内容
|
||
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
var last map[string]any
|
||
var lastWithParts map[string]any
|
||
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
|
||
var collectedTextParts []string // 收集所有文本片段
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 流结束,返回收集的响应
|
||
goto returnResponse
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||
}
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
|
||
if !strings.HasPrefix(trimmed, "data:") {
|
||
continue
|
||
}
|
||
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr != nil {
|
||
continue
|
||
}
|
||
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(inner, &parsed); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 记录首 token 时间
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
last = parsed
|
||
|
||
// 提取 usage
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
|
||
// Check for MALFORMED_FUNCTION_CALL
|
||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||
if cand, ok := candidates[0].(map[string]any); ok {
|
||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||
if content, ok := cand["content"]; ok {
|
||
if b, err := json.Marshal(content); err == nil {
|
||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 保留最后一个有 parts 的响应
|
||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||
lastWithParts = parsed
|
||
// 收集包含图片和文本的 parts
|
||
for _, part := range parts {
|
||
if inlineData, ok := part["inlineData"].(map[string]any); ok {
|
||
collectedImageParts = append(collectedImageParts, part)
|
||
_ = inlineData // 避免 unused 警告
|
||
}
|
||
if text, ok := part["text"].(string); ok && text != "" {
|
||
collectedTextParts = append(collectedTextParts, text)
|
||
}
|
||
}
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity non-stream)")
|
||
return nil, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
|
||
returnResponse:
|
||
// 选择最后一个有效响应
|
||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||
|
||
// 处理空响应情况
|
||
if last == nil && lastWithParts == nil {
|
||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||
}
|
||
|
||
// 如果收集到了图片 parts,需要合并到最终响应中
|
||
if len(collectedImageParts) > 0 {
|
||
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
|
||
}
|
||
|
||
// 如果收集到了文本,需要合并到最终响应中
|
||
if len(collectedTextParts) > 0 {
|
||
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
|
||
}
|
||
|
||
respBody, err := json.Marshal(finalResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal response: %w", err)
|
||
}
|
||
c.Data(http.StatusOK, "application/json", respBody)
|
||
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
|
||
func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
|
||
// 深拷贝 response
|
||
result = make(map[string]any)
|
||
for k, v := range response {
|
||
result[k] = v
|
||
}
|
||
|
||
// 获取或创建 candidates
|
||
candidates, ok := result["candidates"].([]any)
|
||
if !ok || len(candidates) == 0 {
|
||
candidates = []any{map[string]any{}}
|
||
}
|
||
|
||
// 获取第一个 candidate
|
||
candidate, ok := candidates[0].(map[string]any)
|
||
if !ok {
|
||
candidate = make(map[string]any)
|
||
candidates[0] = candidate
|
||
}
|
||
|
||
// 获取或创建 content
|
||
content, ok := candidate["content"].(map[string]any)
|
||
if !ok {
|
||
content = map[string]any{"role": "model"}
|
||
candidate["content"] = content
|
||
}
|
||
|
||
// 获取现有 parts
|
||
existingParts, ok = content["parts"].([]any)
|
||
if !ok {
|
||
existingParts = []any{}
|
||
}
|
||
|
||
// 返回更新回调
|
||
setParts = func(newParts []any) {
|
||
content["parts"] = newParts
|
||
result["candidates"] = candidates
|
||
}
|
||
|
||
return result, existingParts, setParts
|
||
}
|
||
|
||
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||
// 保持原始顺序,只合并连续的普通 text parts
|
||
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||
if len(collectedParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
result, _, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 合并策略:
|
||
// 1. 保持原始顺序
|
||
// 2. 连续的普通 text parts 合并为一个
|
||
// 3. thinking、functionCall、inlineData 等保持原样
|
||
var mergedParts []any
|
||
var textBuffer strings.Builder
|
||
|
||
flushTextBuffer := func() {
|
||
if textBuffer.Len() > 0 {
|
||
mergedParts = append(mergedParts, map[string]any{
|
||
"text": textBuffer.String(),
|
||
})
|
||
textBuffer.Reset()
|
||
}
|
||
}
|
||
|
||
for _, part := range collectedParts {
|
||
// 检查是否是普通 text part
|
||
if text, ok := part["text"].(string); ok {
|
||
// 检查是否有 thought 标记
|
||
if thought, _ := part["thought"].(bool); thought {
|
||
// thinking part,先刷新 text buffer,然后保留原样
|
||
flushTextBuffer()
|
||
mergedParts = append(mergedParts, part)
|
||
} else {
|
||
// 普通 text,累积到 buffer
|
||
_, _ = textBuffer.WriteString(text)
|
||
}
|
||
} else {
|
||
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||
flushTextBuffer()
|
||
mergedParts = append(mergedParts, part)
|
||
}
|
||
}
|
||
|
||
// 刷新剩余的 text
|
||
flushTextBuffer()
|
||
|
||
setParts(mergedParts)
|
||
return result
|
||
}
|
||
|
||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||
if len(imageParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
result, existingParts, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 检查现有 parts 中是否已经有图片
|
||
for _, p := range existingParts {
|
||
if pm, ok := p.(map[string]any); ok {
|
||
if _, hasInline := pm["inlineData"]; hasInline {
|
||
return result // 已有图片,不重复添加
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加收集到的图片 parts
|
||
for _, imgPart := range imageParts {
|
||
existingParts = append(existingParts, imgPart)
|
||
}
|
||
setParts(existingParts)
|
||
return result
|
||
}
|
||
|
||
// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中
|
||
func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
|
||
if len(textParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
mergedText := strings.Join(textParts, "")
|
||
result, existingParts, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 查找并更新第一个 text part,或创建新的
|
||
newParts := make([]any, 0, len(existingParts)+1)
|
||
textUpdated := false
|
||
|
||
for _, p := range existingParts {
|
||
pm, ok := p.(map[string]any)
|
||
if !ok {
|
||
newParts = append(newParts, p)
|
||
continue
|
||
}
|
||
if _, hasText := pm["text"]; hasText && !textUpdated {
|
||
// 用累积的文本替换
|
||
newPart := make(map[string]any)
|
||
for k, v := range pm {
|
||
newPart[k] = v
|
||
}
|
||
newPart["text"] = mergedText
|
||
newParts = append(newParts, newPart)
|
||
textUpdated = true
|
||
} else {
|
||
newParts = append(newParts, pm)
|
||
}
|
||
}
|
||
|
||
if !textUpdated {
|
||
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
|
||
}
|
||
|
||
setParts(newParts)
|
||
return result
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
|
||
c.JSON(status, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": message},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
// WriteMappedClaudeError 导出版本,供 handler 层使用(如 fallback 错误处理)
|
||
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
upstreamDetail := s.getUpstreamErrorDetail(body)
|
||
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: upstreamStatus,
|
||
UpstreamRequestID: upstreamRequestID,
|
||
Kind: "http_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
|
||
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
|
||
if logBody {
|
||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||
}
|
||
|
||
var statusCode int
|
||
var errType, errMsg string
|
||
|
||
switch upstreamStatus {
|
||
case 400:
|
||
statusCode = http.StatusBadRequest
|
||
errType = "invalid_request_error"
|
||
errMsg = getPassthroughOrDefault(upstreamMsg, "Invalid request")
|
||
case 401:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "authentication_error"
|
||
errMsg = "Upstream authentication failed"
|
||
case 403:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "permission_error"
|
||
errMsg = "Upstream access forbidden"
|
||
case 429:
|
||
statusCode = http.StatusTooManyRequests
|
||
errType = "rate_limit_error"
|
||
errMsg = "Upstream rate limit exceeded"
|
||
case 529:
|
||
statusCode = http.StatusServiceUnavailable
|
||
errType = "overloaded_error"
|
||
errMsg = "Upstream service overloaded"
|
||
default:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "upstream_error"
|
||
errMsg = "Upstream request failed"
|
||
}
|
||
|
||
c.JSON(statusCode, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": errMsg},
|
||
})
|
||
if upstreamMsg == "" {
|
||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||
}
|
||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||
statusStr := "UNKNOWN"
|
||
switch status {
|
||
case 400:
|
||
statusStr = "INVALID_ARGUMENT"
|
||
case 404:
|
||
statusStr = "NOT_FOUND"
|
||
case 429:
|
||
statusStr = "RESOURCE_EXHAUSTED"
|
||
case 500:
|
||
statusStr = "INTERNAL"
|
||
case 502, 503:
|
||
statusStr = "UNAVAILABLE"
|
||
}
|
||
|
||
c.JSON(status, gin.H{
|
||
"error": gin.H{
|
||
"code": status,
|
||
"message": message,
|
||
"status": statusStr,
|
||
},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
var firstTokenMs *int
|
||
var last map[string]any
|
||
var lastWithParts map[string]any
|
||
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 流结束,转换并返回响应
|
||
goto returnResponse
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||
}
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
|
||
if !strings.HasPrefix(trimmed, "data:") {
|
||
continue
|
||
}
|
||
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr != nil {
|
||
continue
|
||
}
|
||
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(inner, &parsed); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 记录首 token 时间
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
last = parsed
|
||
|
||
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||
lastWithParts = parsed
|
||
|
||
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||
collectedParts = append(collectedParts, parts...)
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
|
||
return nil, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
|
||
returnResponse:
|
||
// 选择最后一个有效响应
|
||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||
|
||
// 处理空响应情况
|
||
if last == nil && lastWithParts == nil {
|
||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||
}
|
||
|
||
// 将收集的所有 parts 合并到最终响应中
|
||
if len(collectedParts) > 0 {
|
||
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||
}
|
||
|
||
// 序列化为 JSON(Gemini 格式)
|
||
geminiBody, err := json.Marshal(finalResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
|
||
}
|
||
|
||
// 转换 Gemini 响应为 Claude 格式
|
||
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
|
||
if err != nil {
|
||
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||
}
|
||
|
||
c.Data(http.StatusOK, "application/json", claudeResp)
|
||
|
||
// 转换为 service.ClaudeUsage
|
||
usage := &ClaudeUsage{
|
||
InputTokens: agUsage.InputTokens,
|
||
OutputTokens: agUsage.OutputTokens,
|
||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||
}
|
||
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
c.Status(http.StatusOK)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||
var firstTokenMs *int
|
||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||
if agUsage == nil {
|
||
return &ClaudeUsage{}
|
||
}
|
||
return &ClaudeUsage{
|
||
InputTokens: agUsage.InputTokens,
|
||
OutputTokens: agUsage.OutputTokens,
|
||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||
}
|
||
}
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
|
||
|
||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||
errorEventSent := false
|
||
sendErrorEvent := func(reason string) {
|
||
if errorEventSent || cw.Disconnected() {
|
||
return
|
||
}
|
||
errorEventSent = true
|
||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||
flusher.Flush()
|
||
}
|
||
|
||
// finishUsage 是获取 processor 最终 usage 的辅助函数
|
||
finishUsage := func() *ClaudeUsage {
|
||
_, agUsage := processor.Finish()
|
||
return convertUsage(agUsage)
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 上游完成,发送结束事件
|
||
finalEvents, agUsage := processor.Finish()
|
||
if len(finalEvents) > 0 {
|
||
cw.Write(finalEvents)
|
||
}
|
||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
|
||
}
|
||
if ev.err != nil {
|
||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
|
||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||
}
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||
sendErrorEvent("response_too_large")
|
||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||
}
|
||
sendErrorEvent("stream_read_error")
|
||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||
}
|
||
|
||
// 处理 SSE 行,转换为 Claude 格式
|
||
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
|
||
if len(claudeEvents) > 0 {
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
cw.Write(claudeEvents)
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
if cw.Disconnected() {
|
||
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
|
||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity)")
|
||
sendErrorEvent("stream_timeout")
|
||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
}
|
||
|
||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
||
var req antigravity.GeminiRequest
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
return "2K" // 默认 2K
|
||
}
|
||
|
||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||
if size == "1K" || size == "2K" || size == "4K" {
|
||
return size
|
||
}
|
||
}
|
||
|
||
return "2K" // 默认 2K
|
||
}
|
||
|
||
// isImageGenerationModel 判断模型是否为图片生成模型
|
||
// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
|
||
func isImageGenerationModel(model string) bool {
|
||
modelLower := strings.ToLower(model)
|
||
// 移除 models/ 前缀
|
||
modelLower = strings.TrimPrefix(modelLower, "models/")
|
||
|
||
// 精确匹配或前缀匹配
|
||
return modelLower == "gemini-3-pro-image" ||
|
||
modelLower == "gemini-3-pro-image-preview" ||
|
||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
|
||
modelLower == "gemini-2.5-flash-image" ||
|
||
modelLower == "gemini-2.5-flash-image-preview" ||
|
||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||
}
|
||
|
||
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
|
||
func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||
var payload map[string]any
|
||
if err := json.Unmarshal(body, &payload); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
modified := false
|
||
|
||
// 1. 清理 Tools
|
||
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
|
||
for _, t := range tools {
|
||
toolMap, ok := t.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// function_declarations (snake_case) or functionDeclarations (camelCase)
|
||
var funcs []any
|
||
if f, ok := toolMap["functionDeclarations"].([]any); ok {
|
||
funcs = f
|
||
} else if f, ok := toolMap["function_declarations"].([]any); ok {
|
||
funcs = f
|
||
}
|
||
|
||
if len(funcs) == 0 {
|
||
continue
|
||
}
|
||
|
||
for _, f := range funcs {
|
||
funcMap, ok := f.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
if params, ok := funcMap["parameters"].(map[string]any); ok {
|
||
antigravity.DeepCleanUndefined(params)
|
||
cleaned := antigravity.CleanJSONSchema(params)
|
||
funcMap["parameters"] = cleaned
|
||
modified = true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !modified {
|
||
return body, nil
|
||
}
|
||
|
||
return json.Marshal(payload)
|
||
}
|
||
|
||
// filterEmptyPartsFromGeminiRequest 过滤掉 parts 为空的消息
|
||
// Gemini API 不接受空 parts,需要在请求前过滤
|
||
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
||
var payload map[string]any
|
||
if err := json.Unmarshal(body, &payload); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
contents, ok := payload["contents"].([]any)
|
||
if !ok || len(contents) == 0 {
|
||
return body, nil
|
||
}
|
||
|
||
filtered := make([]any, 0, len(contents))
|
||
modified := false
|
||
|
||
for _, c := range contents {
|
||
contentMap, ok := c.(map[string]any)
|
||
if !ok {
|
||
filtered = append(filtered, c)
|
||
continue
|
||
}
|
||
|
||
parts, hasParts := contentMap["parts"]
|
||
if !hasParts {
|
||
filtered = append(filtered, c)
|
||
continue
|
||
}
|
||
|
||
partsSlice, ok := parts.([]any)
|
||
if !ok {
|
||
filtered = append(filtered, c)
|
||
continue
|
||
}
|
||
|
||
// 跳过 parts 为空数组的消息
|
||
if len(partsSlice) == 0 {
|
||
modified = true
|
||
continue
|
||
}
|
||
|
||
filtered = append(filtered, c)
|
||
}
|
||
|
||
if !modified {
|
||
return body, nil
|
||
}
|
||
|
||
payload["contents"] = filtered
|
||
return json.Marshal(payload)
|
||
}
|
||
|
||
// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求
|
||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
// 获取上游配置
|
||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||
if baseURL == "" || apiKey == "" {
|
||
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||
}
|
||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||
|
||
// 解析请求获取模型信息
|
||
var claudeReq antigravity.ClaudeRequest
|
||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||
}
|
||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||
return nil, fmt.Errorf("missing model")
|
||
}
|
||
originalModel := claudeReq.Model
|
||
billingModel := originalModel
|
||
|
||
// 构建上游请求 URL
|
||
upstreamURL := baseURL + "/v1/messages"
|
||
|
||
// 创建请求
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||
}
|
||
|
||
// 设置请求头
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||
|
||
// 透传 Claude 相关 headers
|
||
if v := c.GetHeader("anthropic-version"); v != "" {
|
||
req.Header.Set("anthropic-version", v)
|
||
}
|
||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||
req.Header.Set("anthropic-beta", v)
|
||
}
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 发送请求
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 处理错误响应
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
|
||
// 429 错误时标记账号限流
|
||
if resp.StatusCode == http.StatusTooManyRequests {
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", false)
|
||
}
|
||
|
||
// 透传上游错误
|
||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||
c.Status(resp.StatusCode)
|
||
_, _ = c.Writer.Write(respBody)
|
||
|
||
return &ForwardResult{
|
||
Model: billingModel,
|
||
}, nil
|
||
}
|
||
|
||
// 处理成功响应(流式/非流式)
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
var clientDisconnect bool
|
||
|
||
if claudeReq.Stream {
|
||
// 流式响应:透传
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
c.Status(http.StatusOK)
|
||
|
||
streamRes := s.streamUpstreamResponse(c, resp, startTime)
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
clientDisconnect = streamRes.clientDisconnect
|
||
} else {
|
||
// 非流式响应:直接透传
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||
}
|
||
|
||
// 提取 usage
|
||
usage = s.extractClaudeUsage(respBody)
|
||
|
||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||
c.Status(http.StatusOK)
|
||
_, _ = c.Writer.Write(respBody)
|
||
}
|
||
|
||
// 构建计费结果
|
||
duration := time.Since(startTime)
|
||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||
|
||
return &ForwardResult{
|
||
Model: billingModel,
|
||
Stream: claudeReq.Stream,
|
||
Duration: duration,
|
||
FirstTokenMs: firstTokenMs,
|
||
ClientDisconnect: clientDisconnect,
|
||
Usage: ClaudeUsage{
|
||
InputTokens: usage.InputTokens,
|
||
OutputTokens: usage.OutputTokens,
|
||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage
|
||
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
flusher, _ := c.Writer.(http.Flusher)
|
||
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
|
||
}
|
||
if ev.err != nil {
|
||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
|
||
}
|
||
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||
}
|
||
|
||
line := ev.line
|
||
|
||
// 记录首 token 时间
|
||
if firstTokenMs == nil && len(line) > 0 {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||
s.extractSSEUsage(line, usage)
|
||
|
||
// 透传行
|
||
cw.Fprintf("%s\n", line)
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
if cw.Disconnected() {
|
||
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity upstream)")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
|
||
}
|
||
}
|
||
}
|
||
|
||
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
|
||
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
return
|
||
}
|
||
dataStr := strings.TrimPrefix(line, "data: ")
|
||
var event map[string]any
|
||
if json.Unmarshal([]byte(dataStr), &event) != nil {
|
||
return
|
||
}
|
||
u, ok := event["usage"].(map[string]any)
|
||
if !ok {
|
||
return
|
||
}
|
||
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||
usage.InputTokens = int(v)
|
||
}
|
||
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||
usage.OutputTokens = int(v)
|
||
}
|
||
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||
usage.CacheReadInputTokens = int(v)
|
||
}
|
||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||
usage.CacheCreationInputTokens = int(v)
|
||
}
|
||
}
|
||
|
||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||
usage := &ClaudeUsage{}
|
||
var resp map[string]any
|
||
if json.Unmarshal(body, &resp) != nil {
|
||
return usage
|
||
}
|
||
if u, ok := resp["usage"].(map[string]any); ok {
|
||
if v, ok := u["input_tokens"].(float64); ok {
|
||
usage.InputTokens = int(v)
|
||
}
|
||
if v, ok := u["output_tokens"].(float64); ok {
|
||
usage.OutputTokens = int(v)
|
||
}
|
||
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||
usage.CacheReadInputTokens = int(v)
|
||
}
|
||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||
usage.CacheCreationInputTokens = int(v)
|
||
}
|
||
}
|
||
return usage
|
||
}
|