ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response directly to the client (headers + body), not parse it as SSE stream.
3739 lines
123 KiB
Go
3739 lines
123 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
mathrand "math/rand"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// Retry and sticky-session tuning.
const (
	antigravityStickySessionTTL = time.Hour
	antigravityMaxRetries       = 3
	antigravityRetryBaseDelay   = time.Second
	antigravityRetryMaxDelay    = 16 * time.Second
)

// Rate-limit tuning.
const (
	// antigravityRateLimitThreshold is the wait-or-switch cutoff:
	//   - smart retry: a retryDelay below it is waited out and retried, a
	//     retryDelay at or above it rate-limits the model instead;
	//   - pre-check: a remaining limit below it is waited out, a remaining
	//     limit at or above it triggers an account switch.
	antigravityRateLimitThreshold = 7 * time.Second
	// antigravitySmartRetryMinWait is the minimum smart-retry wait.
	antigravitySmartRetryMinWait = time.Second
	// antigravitySmartRetryMaxAttempts caps the number of smart retries.
	antigravitySmartRetryMaxAttempts = 3
	// antigravityDefaultRateLimitDuration is the default rate-limit window
	// (used when no retryDelay is available).
	antigravityDefaultRateLimitDuration = 30 * time.Second
)

// Google RPC status, detail-type and reason constants.
const (
	googleRPCStatusResourceExhausted      = "RESOURCE_EXHAUSTED"
	googleRPCStatusUnavailable            = "UNAVAILABLE"
	googleRPCTypeRetryInfo                = "type.googleapis.com/google.rpc.RetryInfo"
	googleRPCTypeErrorInfo                = "type.googleapis.com/google.rpc.ErrorInfo"
	googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
	googleRPCReasonRateLimitExceeded      = "RATE_LIMIT_EXCEEDED"
)
|
||
|
||
// upstreamHopByHopHeaders lists the (lower-case) hop-by-hop header names that
// must be excluded when request headers are passed through to the upstream.
var upstreamHopByHopHeaders = map[string]bool{
	"connection":          true,
	"content-length":      true,
	"host":                true,
	"keep-alive":          true,
	"proxy-authenticate":  true,
	"proxy-authorization": true,
	"proxy-connection":    true,
	"te":                  true,
	"trailer":             true,
	"transfer-encoding":   true,
	"upgrade":             true,
}
|
||
|
||
// antigravityPassthroughErrorMessages is the whitelist (lower-case) of
// upstream error messages that are passed through to the client.
// Entries are matched with strings.Contains, so a substring match is enough.
var antigravityPassthroughErrorMessages = []string{"prompt is too long"}
|
||
|
||
// Names of the environment variables read by the Antigravity gateway.
const antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
const antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||
|
||
// AntigravityAccountSwitchError is the account-switch signal.
// It tells the upper layer to switch accounts when the account's rate-limit
// time exceeds the threshold.
type AntigravityAccountSwitchError struct {
	// OriginalAccountID is the ID of the account that was rate limited.
	OriginalAccountID int64
	// RateLimitedModel is the model that was rate limited on that account.
	RateLimitedModel string
	// IsStickySession reports whether the switch happened on a sticky session
	// (decides whether cached billing applies).
	IsStickySession bool
}

// Error implements the error interface.
func (e *AntigravityAccountSwitchError) Error() string {
	accountID, model := e.OriginalAccountID, e.RateLimitedModel
	return fmt.Sprintf("account %d model %s rate limited, need switch", accountID, model)
}
|
||
|
||
// IsAntigravityAccountSwitchError 检查错误是否为账号切换信号
|
||
func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, bool) {
|
||
var switchErr *AntigravityAccountSwitchError
|
||
if errors.As(err, &switchErr) {
|
||
return switchErr, true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// PromptTooLongError indicates that the upstream explicitly answered
// "prompt too long".
type PromptTooLongError struct {
	// StatusCode is the status code carried by this error; it is the only
	// field included in the Error() text.
	StatusCode int
	// RequestID is the request identifier associated with the failure.
	RequestID string
	// Body is the raw response body associated with the failure.
	Body []byte
}

// Error implements the error interface; only the status code is reported.
func (e *PromptTooLongError) Error() string {
	status := e.StatusCode
	return fmt.Sprintf("prompt too long: status=%d", status)
}
|
||
|
||
// antigravityRetryLoopParams 重试循环的参数
|
||
type antigravityRetryLoopParams struct {
|
||
ctx context.Context
|
||
prefix string
|
||
account *Account
|
||
proxyURL string
|
||
accessToken string
|
||
action string
|
||
body []byte
|
||
quotaScope AntigravityQuotaScope
|
||
c *gin.Context
|
||
httpUpstream HTTPUpstream
|
||
settingService *SettingService
|
||
accountRepo AccountRepository // 用于智能重试的模型级别限流
|
||
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult
|
||
requestedModel string // 用于限流检查的原始请求模型
|
||
isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断)
|
||
groupID int64 // 用于模型级限流时清除粘性会话
|
||
sessionHash string // 用于模型级限流时清除粘性会话
|
||
}
|
||
|
||
// antigravityRetryLoopResult is the outcome of the retry loop.
type antigravityRetryLoopResult struct {
	// resp is the final upstream response chosen by the loop.
	resp *http.Response
}
|
||
|
||
// smartRetryAction is the outcome category of the smart-retry handling.
type smartRetryAction int

const (
	// smartRetryActionContinue: fall through to the default retry logic.
	smartRetryActionContinue smartRetryAction = iota
	// smartRetryActionBreakWithResp: stop the loop and return resp.
	smartRetryActionBreakWithResp
	// smartRetryActionContinueURL: continue with the URL fallback loop.
	smartRetryActionContinueURL
)
|
||
|
||
// smartRetryResult 智能重试的结果
|
||
type smartRetryResult struct {
|
||
action smartRetryAction
|
||
resp *http.Response
|
||
err error
|
||
switchError *AntigravityAccountSwitchError // 模型限流时返回账号切换信号
|
||
}
|
||
|
||
// handleSmartRetry 处理 OAuth 账号的智能重试逻辑
|
||
// 将 429/503 限流处理逻辑抽取为独立函数,减少 antigravityRetryLoop 的复杂度
|
||
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
|
||
// "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429)
|
||
if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||
return &smartRetryResult{action: smartRetryActionContinueURL}
|
||
}
|
||
|
||
// 判断是否触发智能重试
|
||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||
|
||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||
if shouldRateLimitModel {
|
||
log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)",
|
||
p.prefix, resp.StatusCode, modelName, p.account.ID)
|
||
|
||
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
|
||
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
|
||
log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID)
|
||
} else {
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||
}
|
||
|
||
// 返回账号切换信号,让上层切换账号重试
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
switchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: modelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
|
||
if shouldSmartRetry {
|
||
var lastRetryResp *http.Response
|
||
var lastRetryBody []byte
|
||
|
||
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
|
||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
|
||
|
||
select {
|
||
case <-p.ctx.Done():
|
||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||
case <-time.After(waitDuration):
|
||
}
|
||
|
||
// 智能重试:创建新请求
|
||
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||
if err != nil {
|
||
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
resp: &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
},
|
||
}
|
||
}
|
||
|
||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||
}
|
||
|
||
// 网络错误时,继续重试
|
||
if retryErr != nil || retryResp == nil {
|
||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
||
continue
|
||
}
|
||
|
||
// 重试失败,关闭之前的响应
|
||
if lastRetryResp != nil {
|
||
_ = lastRetryResp.Body.Close()
|
||
}
|
||
lastRetryResp = retryResp
|
||
if retryResp != nil {
|
||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||
_ = retryResp.Body.Close()
|
||
}
|
||
|
||
// 解析新的重试信息,用于下次重试的等待时间
|
||
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
|
||
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||
if newShouldRetry && newWaitDuration > 0 {
|
||
waitDuration = newWaitDuration
|
||
}
|
||
}
|
||
}
|
||
|
||
// 所有重试都失败,限流当前模型并切换账号
|
||
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)",
|
||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID)
|
||
|
||
resetAt := time.Now().Add(antigravityDefaultRateLimitDuration)
|
||
if p.accountRepo != nil && modelName != "" {
|
||
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
|
||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
|
||
} else {
|
||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
|
||
p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration)
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
|
||
}
|
||
}
|
||
|
||
// 返回账号切换信号,让上层切换账号重试
|
||
return &smartRetryResult{
|
||
action: smartRetryActionBreakWithResp,
|
||
switchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: modelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 未触发智能重试,继续默认重试逻辑
|
||
return &smartRetryResult{action: smartRetryActionContinue}
|
||
}
|
||
|
||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||
// 预检查:如果账号已限流,根据剩余时间决定等待或切换
|
||
if p.requestedModel != "" {
|
||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||
if remaining < antigravityRateLimitThreshold {
|
||
// 限流剩余时间较短,等待后继续
|
||
log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d",
|
||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||
select {
|
||
case <-p.ctx.Done():
|
||
return nil, p.ctx.Err()
|
||
case <-time.After(remaining):
|
||
}
|
||
} else {
|
||
// 限流剩余时间较长,返回账号切换信号
|
||
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
||
p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID)
|
||
return nil, &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: p.requestedModel,
|
||
IsStickySession: p.isStickySession,
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||
if len(availableURLs) == 0 {
|
||
availableURLs = antigravity.BaseURLs
|
||
}
|
||
|
||
var resp *http.Response
|
||
var usedBaseURL string
|
||
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||
maxBytes := 2048
|
||
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||
}
|
||
getUpstreamDetail := func(body []byte) string {
|
||
if !logBody {
|
||
return ""
|
||
}
|
||
return truncateString(string(body), maxBytes)
|
||
}
|
||
|
||
urlFallbackLoop:
|
||
for urlIdx, baseURL := range availableURLs {
|
||
usedBaseURL = baseURL
|
||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||
select {
|
||
case <-p.ctx.Done():
|
||
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||
return nil, p.ctx.Err()
|
||
default:
|
||
}
|
||
|
||
upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Capture upstream request body for ops retry of this attempt.
|
||
if p.c != nil && len(p.body) > 0 {
|
||
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
|
||
}
|
||
|
||
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||
if err == nil && resp == nil {
|
||
err = errors.New("upstream returned nil response")
|
||
}
|
||
if err != nil {
|
||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: 0,
|
||
Kind: "request_error",
|
||
Message: safeErr,
|
||
})
|
||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||
continue urlFallbackLoop
|
||
}
|
||
if attempt < antigravityMaxRetries {
|
||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
|
||
setOpsUpstreamError(p.c, 0, safeErr, "")
|
||
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
|
||
}
|
||
|
||
// 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流
|
||
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
|
||
// 尝试智能重试处理(OAuth 账号专用)
|
||
smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs)
|
||
switch smartResult.action {
|
||
case smartRetryActionContinueURL:
|
||
continue urlFallbackLoop
|
||
case smartRetryActionBreakWithResp:
|
||
if smartResult.err != nil {
|
||
return nil, smartResult.err
|
||
}
|
||
// 模型限流时返回切换账号信号
|
||
if smartResult.switchError != nil {
|
||
return nil, smartResult.switchError
|
||
}
|
||
resp = smartResult.resp
|
||
break urlFallbackLoop
|
||
}
|
||
// smartRetryActionContinue: 继续默认重试逻辑
|
||
|
||
// 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败)
|
||
if attempt < antigravityMaxRetries {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "retry",
|
||
Message: upstreamMsg,
|
||
Detail: getUpstreamDetail(respBody),
|
||
})
|
||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 重试用尽,标记账户限流
|
||
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession)
|
||
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break urlFallbackLoop
|
||
}
|
||
|
||
// 其他可重试错误(不包括 429 和 503,因为上面已处理)
|
||
if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close()
|
||
|
||
if attempt < antigravityMaxRetries {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||
Platform: p.account.Platform,
|
||
AccountID: p.account.ID,
|
||
AccountName: p.account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "retry",
|
||
Message: upstreamMsg,
|
||
Detail: getUpstreamDetail(respBody),
|
||
})
|
||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||
return nil, p.ctx.Err()
|
||
}
|
||
continue
|
||
}
|
||
resp = &http.Response{
|
||
StatusCode: resp.StatusCode,
|
||
Header: resp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||
}
|
||
break urlFallbackLoop
|
||
}
|
||
|
||
break urlFallbackLoop
|
||
}
|
||
}
|
||
|
||
if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
|
||
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
|
||
}
|
||
|
||
return &antigravityRetryLoopResult{resp: resp}, nil
|
||
}
|
||
|
||
// shouldRetryAntigravityError reports whether an upstream status code is
// retryable: 429, 500, 502, 503, 504 or 529.
func shouldRetryAntigravityError(statusCode int) bool {
	return statusCode == 429 ||
		statusCode == 500 ||
		statusCode == 502 ||
		statusCode == 503 ||
		statusCode == 504 ||
		statusCode == 529
}
|
||
|
||
// isURLLevelRateLimit reports whether body describes a URL-level rate limit,
// i.e. one where retrying on another URL may succeed.
//
// "Resource has been exhausted" is a URL/node-level limit, while
// "exhausted your capacity on this model" is an account/model quota limit
// for which switching URL does not help. The check is a plain substring
// test: the first phrase must be present and "capacity on this model" absent.
func isURLLevelRateLimit(body []byte) bool {
	text := string(body)
	if !strings.Contains(text, "Resource has been exhausted") {
		return false
	}
	return !strings.Contains(text, "capacity on this model")
}
|
||
|
||
// isAntigravityConnectionError reports whether err is a connection error
// (network timeout, DNS failure, connection refused).
//
// It is true when the error chain contains a net.Error whose Timeout() is
// true, or any *net.OpError. A nil error is never a connection error.
func isAntigravityConnectionError(err error) bool {
	var (
		netErr net.Error
		opErr  *net.OpError
	)
	switch {
	case err == nil:
		return false
	case errors.As(err, &netErr) && netErr.Timeout():
		// Timeout reported by the network layer.
		return true
	default:
		// Connection-level failure (DNS failure, connection refused).
		return errors.As(err, &opErr)
	}
}
|
||
|
||
// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL
|
||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||
func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool {
|
||
if isAntigravityConnectionError(err) {
|
||
return true
|
||
}
|
||
return statusCode == http.StatusTooManyRequests
|
||
}
|
||
|
||
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
|
||
func getSessionID(c *gin.Context) string {
|
||
if c == nil {
|
||
return ""
|
||
}
|
||
return c.GetHeader("session_id")
|
||
}
|
||
|
||
// logPrefix builds the uniform log prefix. The session part is included only
// when sessionID is non-empty.
func logPrefix(sessionID, accountName string) string {
	if sessionID == "" {
		return fmt.Sprintf("[antigravity-Forward] account=%s", accountName)
	}
	return fmt.Sprintf("[antigravity-Forward] session=%s account=%s", sessionID, accountName)
}
|
||
|
||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||
type AntigravityGatewayService struct {
|
||
accountRepo AccountRepository
|
||
tokenProvider *AntigravityTokenProvider
|
||
rateLimitService *RateLimitService
|
||
httpUpstream HTTPUpstream
|
||
settingService *SettingService
|
||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||
schedulerSnapshot *SchedulerSnapshotService
|
||
}
|
||
|
||
func NewAntigravityGatewayService(
|
||
accountRepo AccountRepository,
|
||
cache GatewayCache,
|
||
schedulerSnapshot *SchedulerSnapshotService,
|
||
tokenProvider *AntigravityTokenProvider,
|
||
rateLimitService *RateLimitService,
|
||
httpUpstream HTTPUpstream,
|
||
settingService *SettingService,
|
||
) *AntigravityGatewayService {
|
||
return &AntigravityGatewayService{
|
||
accountRepo: accountRepo,
|
||
tokenProvider: tokenProvider,
|
||
rateLimitService: rateLimitService,
|
||
httpUpstream: httpUpstream,
|
||
settingService: settingService,
|
||
cache: cache,
|
||
schedulerSnapshot: schedulerSnapshot,
|
||
}
|
||
}
|
||
|
||
// GetTokenProvider 返回 token provider
|
||
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
|
||
return s.tokenProvider
|
||
}
|
||
|
||
// getLogConfig 获取上游错误日志配置
|
||
// 返回是否记录日志体和最大字节数
|
||
func (s *AntigravityGatewayService) getLogConfig() (logBody bool, maxBytes int) {
|
||
maxBytes = 2048 // 默认值
|
||
if s.settingService == nil || s.settingService.cfg == nil {
|
||
return false, maxBytes
|
||
}
|
||
cfg := s.settingService.cfg.Gateway
|
||
if cfg.LogUpstreamErrorBodyMaxBytes > 0 {
|
||
maxBytes = cfg.LogUpstreamErrorBodyMaxBytes
|
||
}
|
||
return cfg.LogUpstreamErrorBody, maxBytes
|
||
}
|
||
|
||
// getUpstreamErrorDetail 获取上游错误详情(用于日志记录)
|
||
func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string {
|
||
logBody, maxBytes := s.getLogConfig()
|
||
if !logBody {
|
||
return ""
|
||
}
|
||
return truncateString(string(body), maxBytes)
|
||
}
|
||
|
||
// mapAntigravityModel 获取映射后的模型名
|
||
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping)
|
||
// 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号
|
||
func mapAntigravityModel(account *Account, requestedModel string) string {
|
||
if account == nil {
|
||
return ""
|
||
}
|
||
|
||
// 获取映射表(未配置时自动使用 DefaultAntigravityModelMapping)
|
||
mapping := account.GetModelMapping()
|
||
if len(mapping) == 0 {
|
||
return "" // 无映射配置(非 Antigravity 平台)
|
||
}
|
||
|
||
// 通过映射表查询(支持精确匹配 + 通配符)
|
||
mapped := account.GetMappedModel(requestedModel)
|
||
|
||
// 判断是否映射成功(mapped != requestedModel 说明找到了映射规则)
|
||
if mapped != requestedModel {
|
||
return mapped
|
||
}
|
||
|
||
// 如果 mapped == requestedModel,检查是否在映射表中配置(精确或通配符)
|
||
// 这区分两种情况:
|
||
// 1. 映射表中有 "model-a": "model-a"(显式透传)→ 返回 model-a
|
||
// 2. 通配符匹配 "claude-*": "claude-sonnet-4-5" 恰好目标等于请求名 → 返回 model-a
|
||
// 3. 映射表中没有 model-a 的配置 → 返回空(不支持)
|
||
if account.IsModelSupported(requestedModel) {
|
||
return requestedModel
|
||
}
|
||
|
||
// 未在映射表中配置的模型,返回空字符串(不支持)
|
||
return ""
|
||
}
|
||
|
||
// getMappedModel 获取映射后的模型名
|
||
// 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底
|
||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||
return mapAntigravityModel(account, requestedModel)
|
||
}
|
||
|
||
// applyThinkingModelSuffix adjusts the model name for thinking mode: when the
// mapped model is exactly "claude-sonnet-4-5" and thinking is enabled it
// becomes "claude-sonnet-4-5-thinking". Every other input is returned as is.
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
	if thinkingEnabled && mappedModel == "claude-sonnet-4-5" {
		return "claude-sonnet-4-5-thinking"
	}
	return mappedModel
}
|
||
|
||
// IsModelSupported 检查模型是否被支持
|
||
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||
return strings.HasPrefix(requestedModel, "claude-") ||
|
||
strings.HasPrefix(requestedModel, "gemini-")
|
||
}
|
||
|
||
// TestConnectionResult is the outcome of a connection test.
type TestConnectionResult struct {
	// Text is the response text.
	Text string
	// MappedModel is the model that was actually used.
	MappedModel string
}
|
||
|
||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||
if account.Type == AccountTypeUpstream {
|
||
return s.testUpstreamConnection(ctx, account, modelID)
|
||
}
|
||
|
||
// 获取 token
|
||
if s.tokenProvider == nil {
|
||
return nil, errors.New("antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 模型映射
|
||
mappedModel := s.getMappedModel(account, modelID)
|
||
if mappedModel == "" {
|
||
return nil, fmt.Errorf("model %s not in whitelist", modelID)
|
||
}
|
||
|
||
// 构建请求体
|
||
var requestBody []byte
|
||
if strings.HasPrefix(modelID, "gemini-") {
|
||
// Gemini 模型:直接使用 Gemini 格式
|
||
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
|
||
} else {
|
||
// Claude 模型:使用协议转换
|
||
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||
}
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// URL fallback 循环
|
||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||
if len(availableURLs) == 0 {
|
||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||
}
|
||
|
||
var lastErr error
|
||
for urlIdx, baseURL := range availableURLs {
|
||
// 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致)
|
||
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
// 调试日志:Test 请求信息
|
||
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
|
||
|
||
// 发送请求
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("请求失败: %w", err)
|
||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||
continue
|
||
}
|
||
return nil, lastErr
|
||
}
|
||
|
||
// 读取响应
|
||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
// 检查是否需要 URL 降级
|
||
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||
continue
|
||
}
|
||
|
||
if resp.StatusCode >= 400 {
|
||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
// 解析流式响应,提取文本
|
||
text := extractTextFromSSEResponse(respBody)
|
||
|
||
// 标记成功的 URL,下次优先使用
|
||
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
|
||
return &TestConnectionResult{
|
||
Text: text,
|
||
MappedModel: mappedModel,
|
||
}, nil
|
||
}
|
||
|
||
return nil, lastErr
|
||
}
|
||
|
||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
|
||
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
|
||
payload := map[string]any{
|
||
"contents": []map[string]any{
|
||
{
|
||
"role": "user",
|
||
"parts": []map[string]any{
|
||
{"text": "."},
|
||
},
|
||
},
|
||
},
|
||
// Antigravity 上游要求必须包含身份提示词
|
||
"systemInstruction": map[string]any{
|
||
"parts": []map[string]any{
|
||
{"text": antigravity.GetDefaultIdentityPatch()},
|
||
},
|
||
},
|
||
"generationConfig": map[string]any{
|
||
"maxOutputTokens": 1,
|
||
},
|
||
}
|
||
payloadBytes, _ := json.Marshal(payload)
|
||
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
|
||
}
|
||
|
||
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
|
||
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
|
||
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
|
||
claudeReq := &antigravity.ClaudeRequest{
|
||
Model: mappedModel,
|
||
Messages: []antigravity.ClaudeMessage{
|
||
{
|
||
Role: "user",
|
||
Content: json.RawMessage(`"."`),
|
||
},
|
||
},
|
||
MaxTokens: 1,
|
||
Stream: false,
|
||
}
|
||
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
|
||
opts := antigravity.DefaultTransformOptions()
|
||
if s.settingService == nil {
|
||
return opts
|
||
}
|
||
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
||
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
||
return opts
|
||
}
|
||
|
||
// extractTextFromSSEResponse extracts the generated text from an SSE
// streaming response body.
//
// The body is processed line by line. An optional "data:" prefix is removed,
// lines that are not JSON objects are skipped, and every non-empty
// candidates[0].content.parts[].text value is collected — looked up under
// the "response" key when that key holds an object, otherwise at the top
// level. The collected fragments are concatenated in order.
func extractTextFromSSEResponse(respBody []byte) string {
	var collected strings.Builder
	dataPrefix := []byte("data:")

	for _, raw := range bytes.Split(respBody, []byte("\n")) {
		payload := bytes.TrimSpace(raw)

		// Drop the SSE field prefix, if any.
		if bytes.HasPrefix(payload, dataPrefix) {
			payload = bytes.TrimSpace(payload[len(dataPrefix):])
		}

		// Skip blank lines and anything that is not a JSON object.
		if len(payload) == 0 || payload[0] != '{' {
			continue
		}

		var event map[string]any
		if json.Unmarshal(payload, &event) != nil {
			continue
		}

		// Prefer the wrapped "response" object; some responses carry the
		// candidates at the top level instead.
		holder := event
		if inner, isObject := event["response"].(map[string]any); isObject {
			holder = inner
		}

		candidates, _ := holder["candidates"].([]any)
		if len(candidates) == 0 {
			continue
		}

		// Failed assertions leave nil maps/slices behind; indexing a nil map
		// and ranging over a nil slice are both no-ops, so malformed events
		// simply contribute nothing.
		first, _ := candidates[0].(map[string]any)
		content, _ := first["content"].(map[string]any)
		parts, _ := content["parts"].([]any)

		for _, part := range parts {
			fields, _ := part.(map[string]any)
			if text, _ := fields["text"].(string); text != "" {
				collected.WriteString(text)
			}
		}
	}

	return collected.String()
}
|
||
|
||
// injectIdentityPatchToGeminiRequest 为 Gemini 格式请求注入身份提示词
|
||
// 如果请求中已包含 "You are Antigravity" 则不重复注入
|
||
func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
|
||
var request map[string]any
|
||
if err := json.Unmarshal(body, &request); err != nil {
|
||
return nil, fmt.Errorf("解析 Gemini 请求失败: %w", err)
|
||
}
|
||
|
||
// 检查现有 systemInstruction 是否已包含身份提示词
|
||
if sysInst, ok := request["systemInstruction"].(map[string]any); ok {
|
||
if parts, ok := sysInst["parts"].([]any); ok {
|
||
for _, part := range parts {
|
||
if partMap, ok := part.(map[string]any); ok {
|
||
if text, ok := partMap["text"].(string); ok {
|
||
if strings.Contains(text, "You are Antigravity") {
|
||
// 已包含身份提示词,直接返回原始请求
|
||
return body, nil
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取默认身份提示词
|
||
identityPatch := antigravity.GetDefaultIdentityPatch()
|
||
|
||
// 构建新的 systemInstruction
|
||
newPart := map[string]any{"text": identityPatch}
|
||
|
||
if existing, ok := request["systemInstruction"].(map[string]any); ok {
|
||
// 已有 systemInstruction,在开头插入身份提示词
|
||
if parts, ok := existing["parts"].([]any); ok {
|
||
existing["parts"] = append([]any{newPart}, parts...)
|
||
} else {
|
||
existing["parts"] = []any{newPart}
|
||
}
|
||
} else {
|
||
// 没有 systemInstruction,创建新的
|
||
request["systemInstruction"] = map[string]any{
|
||
"parts": []any{newPart},
|
||
}
|
||
}
|
||
|
||
return json.Marshal(request)
|
||
}
|
||
|
||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||
var request any
|
||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||
}
|
||
|
||
wrapped := map[string]any{
|
||
"project": projectID,
|
||
"requestId": "agent-" + uuid.New().String(),
|
||
"userAgent": "antigravity", // 固定值,与官方客户端一致
|
||
"requestType": "agent",
|
||
"model": model,
|
||
"request": request,
|
||
}
|
||
|
||
return json.Marshal(wrapped)
|
||
}
|
||
|
||
// unwrapV1InternalResponse 解包 v1internal 响应
|
||
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
||
var outer map[string]any
|
||
if err := json.Unmarshal(body, &outer); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp, ok := outer["response"]; ok {
|
||
return json.Marshal(resp)
|
||
}
|
||
|
||
return body, nil
|
||
}
|
||
|
||
// isModelNotFoundError reports whether an upstream response should be treated as a
// "model not found" error.
//
// Every 404 qualifies, whatever its body says: a 404 without a specific message is also
// treated as model-not-found, so scanning the body for keywords cannot change the result.
// The body parameter is kept so existing call sites stay unchanged.
func isModelNotFoundError(statusCode int, body []byte) bool {
	return statusCode == 404
}
|
||
|
||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
if account.Type == AccountTypeUpstream {
|
||
return s.ForwardUpstream(ctx, c, account, body, isStickySession)
|
||
}
|
||
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
// 解析 Claude 请求
|
||
var claudeReq antigravity.ClaudeRequest
|
||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
|
||
}
|
||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
|
||
}
|
||
|
||
originalModel := claudeReq.Model
|
||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||
if mappedModel == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||
}
|
||
loadModel := mappedModel
|
||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||
|
||
// 获取 access_token
|
||
if s.tokenProvider == nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token")
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 获取转换选项
|
||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||
transformOpts := s.getClaudeTransformOptions(ctx)
|
||
transformOpts.EnableIdentityPatch = true // 强制启用,Antigravity 上游必需
|
||
|
||
// 转换 Claude 请求为 Gemini 格式
|
||
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
|
||
}
|
||
|
||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||
action := "streamGenerateContent"
|
||
|
||
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
|
||
if s.cache != nil {
|
||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel)
|
||
}
|
||
|
||
// 执行带重试的请求
|
||
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: action,
|
||
body: geminiBody,
|
||
quotaScope: quotaScope,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession, // Forward 由上层判断粘性会话
|
||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if err != nil {
|
||
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
|
||
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: http.StatusServiceUnavailable,
|
||
ForceCacheBilling: switchErr.IsStickySession,
|
||
}
|
||
}
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||
}
|
||
resp := result.resp
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
|
||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "signature_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
|
||
// Conservative two-stage fallback:
|
||
// 1) Disable top-level thinking + thinking->text
|
||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||
|
||
retryStages := []struct {
|
||
name string
|
||
strip func(*antigravity.ClaudeRequest) (bool, error)
|
||
}{
|
||
{name: "thinking-only", strip: stripThinkingFromClaudeRequest},
|
||
{name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
|
||
}
|
||
|
||
for _, stage := range retryStages {
|
||
retryClaudeReq := claudeReq
|
||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||
|
||
stripped, stripErr := stage.strip(&retryClaudeReq)
|
||
if stripErr != nil || !stripped {
|
||
continue
|
||
}
|
||
|
||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||
|
||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||
if txErr != nil {
|
||
continue
|
||
}
|
||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: action,
|
||
body: retryGeminiBody,
|
||
quotaScope: quotaScope,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession,
|
||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if retryErr != nil {
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: 0,
|
||
Kind: "signature_retry_request_error",
|
||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||
})
|
||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||
continue
|
||
}
|
||
|
||
retryResp := retryResult.resp
|
||
if retryResp.StatusCode < 400 {
|
||
_ = resp.Body.Close()
|
||
resp = retryResp
|
||
respBody = nil
|
||
break
|
||
}
|
||
|
||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||
_ = retryResp.Body.Close()
|
||
if retryResp.StatusCode == http.StatusTooManyRequests {
|
||
retryBaseURL := ""
|
||
if retryResp.Request != nil && retryResp.Request.URL != nil {
|
||
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
|
||
}
|
||
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
|
||
}
|
||
kind := "signature_retry"
|
||
if strings.TrimSpace(stage.name) != "" {
|
||
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
|
||
}
|
||
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
|
||
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
|
||
retryUpstreamDetail := ""
|
||
if logBody {
|
||
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: retryResp.StatusCode,
|
||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||
Kind: kind,
|
||
Message: retryUpstreamMsg,
|
||
Detail: retryUpstreamDetail,
|
||
})
|
||
|
||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||
respBody = retryBody
|
||
resp = &http.Response{
|
||
StatusCode: retryResp.StatusCode,
|
||
Header: retryResp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||
}
|
||
break
|
||
}
|
||
|
||
// Still signature-related; capture context and allow next stage.
|
||
respBody = retryBody
|
||
resp = &http.Response{
|
||
StatusCode: retryResp.StatusCode,
|
||
Header: retryResp.Header.Clone(),
|
||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||
}
|
||
}
|
||
}
|
||
|
||
// 处理错误响应(重试后仍失败或不触发重试)
|
||
if resp.StatusCode >= 400 {
|
||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
if logBody {
|
||
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "prompt_too_long",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &PromptTooLongError{
|
||
StatusCode: resp.StatusCode,
|
||
RequestID: resp.Header.Get("x-request-id"),
|
||
Body: respBody,
|
||
}
|
||
}
|
||
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||
}
|
||
|
||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||
}
|
||
}
|
||
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
if claudeReq.Stream {
|
||
// 客户端要求流式,直接透传转换
|
||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
} else {
|
||
// 客户端要求非流式,收集流式响应后转换返回
|
||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel, // 使用原始模型用于计费和日志
|
||
Stream: claudeReq.Stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
}, nil
|
||
}
|
||
|
||
func isSignatureRelatedError(respBody []byte) bool {
|
||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||
if msg == "" {
|
||
// Fallback: best-effort scan of the raw payload.
|
||
msg = strings.ToLower(string(respBody))
|
||
}
|
||
|
||
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
|
||
if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||
return true
|
||
}
|
||
|
||
// Also detect thinking block structural errors:
|
||
// "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// isPromptTooLongError 检测是否为 prompt too long 错误
|
||
func isPromptTooLongError(respBody []byte) bool {
|
||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||
if msg == "" {
|
||
msg = strings.ToLower(string(respBody))
|
||
}
|
||
return strings.Contains(msg, "prompt is too long") ||
|
||
strings.Contains(msg, "request is too long") ||
|
||
strings.Contains(msg, "context length exceeded") ||
|
||
strings.Contains(msg, "max_tokens")
|
||
}
|
||
|
||
// isPassthroughErrorMessage 检查错误消息是否在透传白名单中
|
||
func isPassthroughErrorMessage(msg string) bool {
|
||
lower := strings.ToLower(msg)
|
||
for _, pattern := range antigravityPassthroughErrorMessages {
|
||
if strings.Contains(lower, pattern) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// getPassthroughOrDefault 若消息在白名单内则返回原始消息,否则返回默认消息
|
||
func getPassthroughOrDefault(upstreamMsg, defaultMsg string) string {
|
||
if isPassthroughErrorMessage(upstreamMsg) {
|
||
return upstreamMsg
|
||
}
|
||
return defaultMsg
|
||
}
|
||
|
||
// extractAntigravityErrorMessage pulls a human-readable error message out of an upstream JSON
// error body. It prefers the Google-style nested form {"error": {"message": "..."}} and falls
// back to a top-level "message" field. It returns "" when the body is not a JSON object or no
// non-blank string message is found.
func extractAntigravityErrorMessage(body []byte) string {
	var doc map[string]any
	if json.Unmarshal(body, &doc) != nil {
		return ""
	}

	// Candidates in priority order: nested error.message first, then top-level message.
	candidates := make([]any, 0, 2)
	if nested, isObj := doc["error"].(map[string]any); isObj {
		candidates = append(candidates, nested["message"])
	}
	candidates = append(candidates, doc["message"])

	for _, candidate := range candidates {
		if text, isStr := candidate.(string); isStr && strings.TrimSpace(text) != "" {
			return text
		}
	}
	return ""
}
|
||
|
||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||
// This preserves the thinking content while avoiding signature validation errors.
|
||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||
// It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
|
||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||
if req == nil {
|
||
return false, nil
|
||
}
|
||
|
||
changed := false
|
||
if req.Thinking != nil {
|
||
req.Thinking = nil
|
||
changed = true
|
||
}
|
||
|
||
for i := range req.Messages {
|
||
raw := req.Messages[i].Content
|
||
if len(raw) == 0 {
|
||
continue
|
||
}
|
||
|
||
// If content is a string, nothing to strip.
|
||
var str string
|
||
if json.Unmarshal(raw, &str) == nil {
|
||
continue
|
||
}
|
||
|
||
// Otherwise treat as an array of blocks and convert thinking blocks to text.
|
||
var blocks []map[string]any
|
||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||
continue
|
||
}
|
||
|
||
filtered := make([]map[string]any, 0, len(blocks))
|
||
modifiedAny := false
|
||
for _, block := range blocks {
|
||
t, _ := block["type"].(string)
|
||
switch t {
|
||
case "thinking":
|
||
thinkingText, _ := block["thinking"].(string)
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
case "redacted_thinking":
|
||
modifiedAny = true
|
||
case "":
|
||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
} else {
|
||
filtered = append(filtered, block)
|
||
}
|
||
default:
|
||
filtered = append(filtered, block)
|
||
}
|
||
}
|
||
|
||
if !modifiedAny {
|
||
continue
|
||
}
|
||
|
||
if len(filtered) == 0 {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": "(content removed)",
|
||
})
|
||
}
|
||
|
||
newRaw, err := json.Marshal(filtered)
|
||
if err != nil {
|
||
return changed, err
|
||
}
|
||
req.Messages[i].Content = newRaw
|
||
changed = true
|
||
}
|
||
|
||
return changed, nil
|
||
}
|
||
|
||
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
|
||
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
|
||
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||
if req == nil {
|
||
return false, nil
|
||
}
|
||
|
||
changed := false
|
||
if req.Thinking != nil {
|
||
req.Thinking = nil
|
||
changed = true
|
||
}
|
||
|
||
for i := range req.Messages {
|
||
raw := req.Messages[i].Content
|
||
if len(raw) == 0 {
|
||
continue
|
||
}
|
||
|
||
// If content is a string, nothing to strip.
|
||
var str string
|
||
if json.Unmarshal(raw, &str) == nil {
|
||
continue
|
||
}
|
||
|
||
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
|
||
var blocks []map[string]any
|
||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||
continue
|
||
}
|
||
|
||
filtered := make([]map[string]any, 0, len(blocks))
|
||
modifiedAny := false
|
||
for _, block := range blocks {
|
||
t, _ := block["type"].(string)
|
||
switch t {
|
||
case "thinking":
|
||
// Convert thinking to text, skip if empty
|
||
thinkingText, _ := block["thinking"].(string)
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
case "redacted_thinking":
|
||
// Remove redacted_thinking (cannot convert encrypted content)
|
||
modifiedAny = true
|
||
case "tool_use":
|
||
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
|
||
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
|
||
name, _ := block["name"].(string)
|
||
id, _ := block["id"].(string)
|
||
input := block["input"]
|
||
inputJSON, _ := json.Marshal(input)
|
||
text := "(tool_use)"
|
||
if name != "" {
|
||
text += " name=" + name
|
||
}
|
||
if id != "" {
|
||
text += " id=" + id
|
||
}
|
||
if len(inputJSON) > 0 && string(inputJSON) != "null" {
|
||
text += " input=" + string(inputJSON)
|
||
}
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": text,
|
||
})
|
||
modifiedAny = true
|
||
case "tool_result":
|
||
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
|
||
toolUseID, _ := block["tool_use_id"].(string)
|
||
isError, _ := block["is_error"].(bool)
|
||
content := block["content"]
|
||
contentJSON, _ := json.Marshal(content)
|
||
text := "(tool_result)"
|
||
if toolUseID != "" {
|
||
text += " tool_use_id=" + toolUseID
|
||
}
|
||
if isError {
|
||
text += " is_error=true"
|
||
}
|
||
if len(contentJSON) > 0 && string(contentJSON) != "null" {
|
||
text += "\n" + string(contentJSON)
|
||
}
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": text,
|
||
})
|
||
modifiedAny = true
|
||
case "":
|
||
// Handle untyped block with "thinking" field
|
||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||
if thinkingText != "" {
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": thinkingText,
|
||
})
|
||
}
|
||
modifiedAny = true
|
||
} else {
|
||
filtered = append(filtered, block)
|
||
}
|
||
default:
|
||
filtered = append(filtered, block)
|
||
}
|
||
}
|
||
|
||
if !modifiedAny {
|
||
continue
|
||
}
|
||
|
||
if len(filtered) == 0 {
|
||
// Keep request valid: upstream rejects empty content arrays.
|
||
filtered = append(filtered, map[string]any{
|
||
"type": "text",
|
||
"text": "(content removed)",
|
||
})
|
||
}
|
||
|
||
newRaw, err := json.Marshal(filtered)
|
||
if err != nil {
|
||
return changed, err
|
||
}
|
||
req.Messages[i].Content = newRaw
|
||
changed = true
|
||
}
|
||
|
||
return changed, nil
|
||
}
|
||
|
||
// ForwardGemini 转发 Gemini 协议请求
|
||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
|
||
if account.Type == AccountTypeUpstream {
|
||
return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession)
|
||
}
|
||
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
if strings.TrimSpace(originalModel) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||
}
|
||
if strings.TrimSpace(action) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||
}
|
||
if len(body) == 0 {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||
}
|
||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||
|
||
// 解析请求以获取 image_size(用于图片计费)
|
||
imageSize := s.extractImageSize(body)
|
||
|
||
switch action {
|
||
case "generateContent", "streamGenerateContent":
|
||
// ok
|
||
case "countTokens":
|
||
// 直接返回空值,不透传上游
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
|
||
return &ForwardResult{
|
||
RequestID: "",
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(time.Now()),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
default:
|
||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||
}
|
||
|
||
mappedModel := s.getMappedModel(account, originalModel)
|
||
if mappedModel == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||
}
|
||
|
||
// 获取 access_token
|
||
if s.tokenProvider == nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Antigravity token provider not configured")
|
||
}
|
||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token")
|
||
}
|
||
|
||
// 获取 project_id(部分账户类型可能没有)
|
||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// Antigravity 上游要求必须包含身份提示词,注入到请求中
|
||
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Invalid request body")
|
||
}
|
||
|
||
// 清理 Schema
|
||
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||
injectedBody = cleanedBody
|
||
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||
} else {
|
||
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||
}
|
||
|
||
// 包装请求
|
||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
|
||
}
|
||
|
||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
|
||
upstreamAction := "streamGenerateContent"
|
||
|
||
// 统计模型调用次数(包括粘性会话,用于负载均衡调度)
|
||
if s.cache != nil {
|
||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
||
}
|
||
|
||
// 执行带重试的请求
|
||
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
proxyURL: proxyURL,
|
||
accessToken: accessToken,
|
||
action: upstreamAction,
|
||
body: wrappedBody,
|
||
quotaScope: quotaScope,
|
||
c: c,
|
||
httpUpstream: s.httpUpstream,
|
||
settingService: s.settingService,
|
||
accountRepo: s.accountRepo,
|
||
handleError: s.handleUpstreamError,
|
||
requestedModel: originalModel,
|
||
isStickySession: isStickySession, // ForwardGemini 由上层判断粘性会话
|
||
groupID: 0, // ForwardGemini 方法没有 groupID,由上层处理粘性会话清除
|
||
sessionHash: "", // ForwardGemini 方法没有 sessionHash,由上层处理粘性会话清除
|
||
})
|
||
if err != nil {
|
||
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
|
||
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: http.StatusServiceUnavailable,
|
||
ForceCacheBilling: switchErr.IsStickySession,
|
||
}
|
||
}
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||
}
|
||
resp := result.resp
|
||
defer func() {
|
||
if resp != nil && resp.Body != nil {
|
||
_ = resp.Body.Close()
|
||
}
|
||
}()
|
||
|
||
// 处理错误响应
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
contentType := resp.Header.Get("Content-Type")
|
||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||
_ = resp.Body.Close()
|
||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||
|
||
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
|
||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||
|
||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||
if err == nil {
|
||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||
if err == nil {
|
||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||
if err == nil && fallbackResp.StatusCode < 400 {
|
||
_ = resp.Body.Close()
|
||
resp = fallbackResp
|
||
} else if fallbackResp != nil {
|
||
_ = fallbackResp.Body.Close()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// fallback 成功:继续按正常响应处理
|
||
if resp.StatusCode < 400 {
|
||
goto handleSuccess
|
||
}
|
||
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
|
||
unwrappedForOps := unwrapped
|
||
if unwrapErr != nil || len(unwrappedForOps) == 0 {
|
||
unwrappedForOps = respBody
|
||
}
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps)
|
||
|
||
// Always record upstream context for Ops error logs, even when we will failover.
|
||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
|
||
}
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "http_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
|
||
c.Data(resp.StatusCode, contentType, unwrappedForOps)
|
||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||
}
|
||
|
||
handleSuccess:
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
var usage *ClaudeUsage
|
||
var firstTokenMs *int
|
||
|
||
if stream {
|
||
// 客户端要求流式,直接透传
|
||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
} else {
|
||
// 客户端要求非流式,收集流式响应后返回
|
||
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
|
||
if err != nil {
|
||
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
|
||
return nil, err
|
||
}
|
||
usage = streamRes.usage
|
||
firstTokenMs = streamRes.firstTokenMs
|
||
}
|
||
|
||
if usage == nil {
|
||
usage = &ClaudeUsage{}
|
||
}
|
||
|
||
// 判断是否为图片生成模型
|
||
imageCount := 0
|
||
if isImageGenerationModel(mappedModel) {
|
||
// Gemini 图片生成 API 每次请求只生成一张图片(API 限制)
|
||
imageCount = 1
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Usage: *usage,
|
||
Model: originalModel,
|
||
Stream: stream,
|
||
Duration: time.Since(startTime),
|
||
FirstTokenMs: firstTokenMs,
|
||
ImageCount: imageCount,
|
||
ImageSize: imageSize,
|
||
}, nil
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||
switch statusCode {
|
||
case 401, 403, 429, 529:
|
||
return true
|
||
default:
|
||
return statusCode >= 500
|
||
}
|
||
}
|
||
|
||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||
delay := antigravityRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||
if delay > antigravityRetryMaxDelay {
|
||
delay = antigravityRetryMaxDelay
|
||
}
|
||
|
||
// +/- 20% jitter
|
||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
|
||
sleepFor := delay + jitter
|
||
if sleepFor < 0 {
|
||
sleepFor = 0
|
||
}
|
||
|
||
select {
|
||
case <-ctx.Done():
|
||
return false
|
||
case <-time.After(sleepFor):
|
||
return true
|
||
}
|
||
}
|
||
|
||
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
|
||
// 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key
|
||
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false)
|
||
func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, accountID int64, modelName, prefix string, statusCode int, resetAt time.Time, afterSmartRetry bool) bool {
|
||
if repo == nil || modelName == "" {
|
||
return false
|
||
}
|
||
// 直接使用官方模型 ID 作为 key,不再转换为 scope
|
||
if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil {
|
||
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
|
||
return false
|
||
}
|
||
if afterSmartRetry {
|
||
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||
} else {
|
||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
|
||
}
|
||
return true
|
||
}
|
||
|
||
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
|
||
if raw == "" {
|
||
return 0, false
|
||
}
|
||
seconds, err := strconv.Atoi(raw)
|
||
if err != nil || seconds <= 0 {
|
||
return 0, false
|
||
}
|
||
return time.Duration(seconds) * time.Second, true
|
||
}
|
||
|
||
// antigravitySmartRetryInfo holds the values the smart-retry logic needs from
// an upstream error body.
type antigravitySmartRetryInfo struct {
	// RetryDelay is how long to wait before retrying.
	RetryDelay time.Duration
	// ModelName is the rate-limited model, e.g. "claude-sonnet-4-5".
	ModelName string
}
|
||
|
||
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
|
||
// 返回解析结果,如果解析失败或不满足条件返回 nil
|
||
//
|
||
// 支持两种情况:
|
||
// 1. 429 RESOURCE_EXHAUSTED + RATE_LIMIT_EXCEEDED:
|
||
// - error.status == "RESOURCE_EXHAUSTED"
|
||
// - error.details[].reason == "RATE_LIMIT_EXCEEDED"
|
||
//
|
||
// 2. 503 UNAVAILABLE + MODEL_CAPACITY_EXHAUSTED:
|
||
// - error.status == "UNAVAILABLE"
|
||
// - error.details[].reason == "MODEL_CAPACITY_EXHAUSTED"
|
||
//
|
||
// 必须满足以下条件才会返回有效值:
|
||
// - error.details[] 中存在 @type == "type.googleapis.com/google.rpc.RetryInfo" 的元素
|
||
// - 该元素包含 retryDelay 字段,格式为 "数字s"(如 "0.201506475s")
|
||
func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||
return nil
|
||
}
|
||
|
||
errObj, ok := parsed["error"].(map[string]any)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
// 检查 status 是否符合条件
|
||
// 情况1: 429 RESOURCE_EXHAUSTED (需要进一步检查 reason == RATE_LIMIT_EXCEEDED)
|
||
// 情况2: 503 UNAVAILABLE (需要进一步检查 reason == MODEL_CAPACITY_EXHAUSTED)
|
||
status, _ := errObj["status"].(string)
|
||
isResourceExhausted := status == googleRPCStatusResourceExhausted
|
||
isUnavailable := status == googleRPCStatusUnavailable
|
||
|
||
if !isResourceExhausted && !isUnavailable {
|
||
return nil
|
||
}
|
||
|
||
details, ok := errObj["details"].([]any)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
var retryDelay time.Duration
|
||
var modelName string
|
||
var hasRateLimitExceeded bool // 429 需要此 reason
|
||
var hasModelCapacityExhausted bool // 503 需要此 reason
|
||
|
||
for _, d := range details {
|
||
dm, ok := d.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
atType, _ := dm["@type"].(string)
|
||
|
||
// 从 ErrorInfo 提取模型名称和 reason
|
||
if atType == googleRPCTypeErrorInfo {
|
||
if meta, ok := dm["metadata"].(map[string]any); ok {
|
||
if model, ok := meta["model"].(string); ok {
|
||
modelName = model
|
||
}
|
||
}
|
||
// 检查 reason
|
||
if reason, ok := dm["reason"].(string); ok {
|
||
if reason == googleRPCReasonModelCapacityExhausted {
|
||
hasModelCapacityExhausted = true
|
||
}
|
||
if reason == googleRPCReasonRateLimitExceeded {
|
||
hasRateLimitExceeded = true
|
||
}
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 从 RetryInfo 提取重试延迟
|
||
if atType == googleRPCTypeRetryInfo {
|
||
delay, ok := dm["retryDelay"].(string)
|
||
if !ok || delay == "" {
|
||
continue
|
||
}
|
||
// 使用 time.ParseDuration 解析,支持所有 Go duration 格式
|
||
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
|
||
dur, err := time.ParseDuration(delay)
|
||
if err != nil {
|
||
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
|
||
continue
|
||
}
|
||
retryDelay = dur
|
||
}
|
||
}
|
||
|
||
// 验证条件
|
||
// 情况1: RESOURCE_EXHAUSTED 需要有 RATE_LIMIT_EXCEEDED reason
|
||
// 情况2: UNAVAILABLE 需要有 MODEL_CAPACITY_EXHAUSTED reason
|
||
if isResourceExhausted && !hasRateLimitExceeded {
|
||
return nil
|
||
}
|
||
if isUnavailable && !hasModelCapacityExhausted {
|
||
return nil
|
||
}
|
||
|
||
// 必须有模型名才返回有效结果
|
||
if modelName == "" {
|
||
return nil
|
||
}
|
||
|
||
// 如果上游未提供 retryDelay,使用默认限流时间
|
||
if retryDelay <= 0 {
|
||
retryDelay = antigravityDefaultRateLimitDuration
|
||
}
|
||
|
||
return &antigravitySmartRetryInfo{
|
||
RetryDelay: retryDelay,
|
||
ModelName: modelName,
|
||
}
|
||
}
|
||
|
||
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
|
||
// 返回:
|
||
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold)
|
||
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold)
|
||
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0)
|
||
// - modelName: 限流的模型名称
|
||
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
|
||
if account.Platform != PlatformAntigravity {
|
||
return false, false, 0, ""
|
||
}
|
||
|
||
info := parseAntigravitySmartRetryInfo(respBody)
|
||
if info == nil {
|
||
return false, false, 0, ""
|
||
}
|
||
|
||
// retryDelay >= 阈值:直接限流模型,不重试
|
||
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟
|
||
if info.RetryDelay >= antigravityRateLimitThreshold {
|
||
return false, true, 0, info.ModelName
|
||
}
|
||
|
||
// retryDelay < 阈值:智能重试
|
||
waitDuration = info.RetryDelay
|
||
if waitDuration < antigravitySmartRetryMinWait {
|
||
waitDuration = antigravitySmartRetryMinWait
|
||
}
|
||
|
||
return true, false, waitDuration, info.ModelName
|
||
}
|
||
|
||
// handleModelRateLimitParams 模型级限流处理参数
|
||
type handleModelRateLimitParams struct {
|
||
ctx context.Context
|
||
prefix string
|
||
account *Account
|
||
statusCode int
|
||
body []byte
|
||
cache GatewayCache
|
||
groupID int64
|
||
sessionHash string
|
||
isStickySession bool
|
||
}
|
||
|
||
// handleModelRateLimitResult 模型级限流处理结果
|
||
type handleModelRateLimitResult struct {
|
||
Handled bool // 是否已处理
|
||
ShouldRetry bool // 是否等待后重试
|
||
WaitDuration time.Duration // 等待时间
|
||
SwitchError *AntigravityAccountSwitchError // 账号切换错误
|
||
}
|
||
|
||
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
|
||
// 仅处理 429/503,解析模型名和 retryDelay
|
||
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试
|
||
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
||
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
|
||
if p.statusCode != 429 && p.statusCode != 503 {
|
||
return &handleModelRateLimitResult{Handled: false}
|
||
}
|
||
|
||
info := parseAntigravitySmartRetryInfo(p.body)
|
||
if info == nil || info.ModelName == "" {
|
||
return &handleModelRateLimitResult{Handled: false}
|
||
}
|
||
|
||
// < antigravityRateLimitThreshold: 等待后重试
|
||
if info.RetryDelay < antigravityRateLimitThreshold {
|
||
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
||
return &handleModelRateLimitResult{
|
||
Handled: true,
|
||
ShouldRetry: true,
|
||
WaitDuration: info.RetryDelay,
|
||
}
|
||
}
|
||
|
||
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
||
s.setModelRateLimitAndClearSession(p, info)
|
||
|
||
return &handleModelRateLimitResult{
|
||
Handled: true,
|
||
SwitchError: &AntigravityAccountSwitchError{
|
||
OriginalAccountID: p.account.ID,
|
||
RateLimitedModel: info.ModelName,
|
||
IsStickySession: p.isStickySession,
|
||
},
|
||
}
|
||
}
|
||
|
||
// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话
|
||
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
|
||
resetAt := time.Now().Add(info.RetryDelay)
|
||
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
|
||
p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay)
|
||
|
||
// 设置模型限流状态(数据库)
|
||
if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil {
|
||
log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
|
||
}
|
||
|
||
// 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中
|
||
s.updateAccountModelRateLimitInCache(p.ctx, p.account, info.ModelName, resetAt)
|
||
|
||
// 清除粘性会话绑定
|
||
if p.cache != nil && p.sessionHash != "" {
|
||
_ = p.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
|
||
}
|
||
}
|
||
|
||
// updateAccountModelRateLimitInCache 立即更新 Redis 中账号的模型限流状态
|
||
func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx context.Context, account *Account, modelKey string, resetAt time.Time) {
|
||
if s.schedulerSnapshot == nil || account == nil || modelKey == "" {
|
||
return
|
||
}
|
||
|
||
// 更新账号对象的 Extra 字段
|
||
if account.Extra == nil {
|
||
account.Extra = make(map[string]any)
|
||
}
|
||
|
||
limits, _ := account.Extra["model_rate_limits"].(map[string]any)
|
||
if limits == nil {
|
||
limits = make(map[string]any)
|
||
account.Extra["model_rate_limits"] = limits
|
||
}
|
||
|
||
limits[modelKey] = map[string]any{
|
||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||
}
|
||
|
||
// 更新 Redis 快照
|
||
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
|
||
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
|
||
}
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) handleUpstreamError(
|
||
ctx context.Context, prefix string, account *Account,
|
||
statusCode int, headers http.Header, body []byte,
|
||
quotaScope AntigravityQuotaScope,
|
||
groupID int64, sessionHash string, isStickySession bool,
|
||
) *handleModelRateLimitResult {
|
||
// ✨ 模型级限流处理(在原有逻辑之前)
|
||
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
||
ctx: ctx,
|
||
prefix: prefix,
|
||
account: account,
|
||
statusCode: statusCode,
|
||
body: body,
|
||
cache: s.cache,
|
||
groupID: groupID,
|
||
sessionHash: sessionHash,
|
||
isStickySession: isStickySession,
|
||
})
|
||
if result.Handled {
|
||
return result
|
||
}
|
||
|
||
// 503 仅处理模型限流(MODEL_CAPACITY_EXHAUSTED),非模型限流不做额外处理
|
||
// 避免将普通的 503 错误误判为账号问题
|
||
if statusCode == 503 {
|
||
return nil
|
||
}
|
||
|
||
// ========== 原有逻辑,保持不变 ==========
|
||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||
if statusCode == 429 {
|
||
// 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。
|
||
if logBody, maxBytes := s.getLogConfig(); logBody {
|
||
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
|
||
}
|
||
|
||
useScopeLimit := quotaScope != ""
|
||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||
if resetAt == nil {
|
||
// 解析失败:使用默认限流时间(与临时限流保持一致)
|
||
// 可通过配置或环境变量覆盖
|
||
defaultDur := antigravityDefaultRateLimitDuration
|
||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
|
||
defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute
|
||
}
|
||
// 秒级环境变量优先级最高
|
||
if override, ok := antigravityFallbackCooldownSeconds(); ok {
|
||
defaultDur = override
|
||
}
|
||
ra := time.Now().Add(defaultDur)
|
||
if useScopeLimit {
|
||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
|
||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||
}
|
||
} else {
|
||
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
|
||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
|
||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
resetTime := time.Unix(*resetAt, 0)
|
||
if useScopeLimit {
|
||
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
|
||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||
}
|
||
} else {
|
||
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
// 其他错误码继续使用 rateLimitService
|
||
if s.rateLimitService == nil {
|
||
return nil
|
||
}
|
||
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||
if shouldDisable {
|
||
log.Printf("%s status=%d marked_error", prefix, statusCode)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
type antigravityStreamResult struct {
|
||
usage *ClaudeUsage
|
||
firstTokenMs *int
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||
c.Status(resp.StatusCode)
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
contentType = "text/event-stream; charset=utf-8"
|
||
}
|
||
c.Header("Content-Type", contentType)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||
errorEventSent := false
|
||
sendErrorEvent := func(reason string) {
|
||
if errorEventSent {
|
||
return
|
||
}
|
||
errorEventSent = true
|
||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||
flusher.Flush()
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||
sendErrorEvent("response_too_large")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||
}
|
||
sendErrorEvent("stream_read_error")
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
if strings.HasPrefix(trimmed, "data:") {
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||
sendErrorEvent("write_failed")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||
}
|
||
flusher.Flush()
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr == nil && inner != nil {
|
||
payload = string(inner)
|
||
}
|
||
|
||
// 解析 usage
|
||
var parsed map[string]any
|
||
if json.Unmarshal(inner, &parsed) == nil {
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
// Check for MALFORMED_FUNCTION_CALL
|
||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||
if cand, ok := candidates[0].(map[string]any); ok {
|
||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||
if content, ok := cand["content"]; ok {
|
||
if b, err := json.Marshal(content); err == nil {
|
||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
||
sendErrorEvent("write_failed")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||
}
|
||
flusher.Flush()
|
||
continue
|
||
}
|
||
|
||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||
sendErrorEvent("write_failed")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||
}
|
||
flusher.Flush()
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity)")
|
||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||
sendErrorEvent("stream_timeout")
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
|
||
// Gemini 流式响应是增量的,需要累积所有 chunk 的内容
|
||
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
var last map[string]any
|
||
var lastWithParts map[string]any
|
||
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
|
||
var collectedTextParts []string // 收集所有文本片段
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 流结束,返回收集的响应
|
||
goto returnResponse
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||
}
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
|
||
if !strings.HasPrefix(trimmed, "data:") {
|
||
continue
|
||
}
|
||
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr != nil {
|
||
continue
|
||
}
|
||
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(inner, &parsed); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 记录首 token 时间
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
last = parsed
|
||
|
||
// 提取 usage
|
||
if u := extractGeminiUsage(parsed); u != nil {
|
||
usage = u
|
||
}
|
||
|
||
// Check for MALFORMED_FUNCTION_CALL
|
||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||
if cand, ok := candidates[0].(map[string]any); ok {
|
||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||
if content, ok := cand["content"]; ok {
|
||
if b, err := json.Marshal(content); err == nil {
|
||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 保留最后一个有 parts 的响应
|
||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||
lastWithParts = parsed
|
||
// 收集包含图片和文本的 parts
|
||
for _, part := range parts {
|
||
if inlineData, ok := part["inlineData"].(map[string]any); ok {
|
||
collectedImageParts = append(collectedImageParts, part)
|
||
_ = inlineData // 避免 unused 警告
|
||
}
|
||
if text, ok := part["text"].(string); ok && text != "" {
|
||
collectedTextParts = append(collectedTextParts, text)
|
||
}
|
||
}
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity non-stream)")
|
||
return nil, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
|
||
returnResponse:
|
||
// 选择最后一个有效响应
|
||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||
|
||
// 处理空响应情况
|
||
if last == nil && lastWithParts == nil {
|
||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||
}
|
||
|
||
// 如果收集到了图片 parts,需要合并到最终响应中
|
||
if len(collectedImageParts) > 0 {
|
||
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
|
||
}
|
||
|
||
// 如果收集到了文本,需要合并到最终响应中
|
||
if len(collectedTextParts) > 0 {
|
||
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
|
||
}
|
||
|
||
respBody, err := json.Marshal(finalResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal response: %w", err)
|
||
}
|
||
c.Data(http.StatusOK, "application/json", respBody)
|
||
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// getOrCreateGeminiParts prepares a Gemini response for having the parts of its
// first candidate replaced.
//
// It returns:
//   - result: a copy of response. The candidates slice, the first candidate and
//     its content map are copied as well, so neither filling in missing pieces
//     nor calling setParts modifies the caller's response.
//   - existingParts: a copy of the current candidates[0].content.parts list
//     (empty, never nil, when there is none). The part values themselves are
//     shared with response.
//   - setParts: stores a new parts list in result and attaches the prepared
//     candidates to it.
//
// A missing candidate or content is created on the fly; new content gets
// "role": "model". Until setParts is called, result["candidates"] still holds
// whatever response had.
func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
	// Top-level copy of the response.
	result = make(map[string]any, len(response))
	for k, v := range response {
		result[k] = v
	}

	// Copy (or create) the candidates slice so that replacing its first element
	// does not write into the caller's slice.
	var candidates []any
	if orig, ok := result["candidates"].([]any); ok && len(orig) > 0 {
		candidates = make([]any, len(orig))
		copy(candidates, orig)
	} else {
		candidates = []any{map[string]any{}}
	}

	// Copy (or create) the first candidate.
	candidate := make(map[string]any)
	if orig, ok := candidates[0].(map[string]any); ok {
		for k, v := range orig {
			candidate[k] = v
		}
	}
	candidates[0] = candidate

	// Copy (or create) its content.
	var content map[string]any
	if orig, ok := candidate["content"].(map[string]any); ok {
		content = make(map[string]any, len(orig))
		for k, v := range orig {
			content[k] = v
		}
	} else {
		content = map[string]any{"role": "model"}
	}
	candidate["content"] = content

	// Hand out a copy of the parts list so appending to it cannot touch the
	// backing array of the original.
	existingParts = []any{}
	if orig, ok := content["parts"].([]any); ok {
		existingParts = make([]any, len(orig))
		copy(existingParts, orig)
	}

	setParts = func(newParts []any) {
		content["parts"] = newParts
		result["candidates"] = candidates
	}

	return result, existingParts, setParts
}
|
||
|
||
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||
// 保持原始顺序,只合并连续的普通 text parts
|
||
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||
if len(collectedParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
result, _, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 合并策略:
|
||
// 1. 保持原始顺序
|
||
// 2. 连续的普通 text parts 合并为一个
|
||
// 3. thinking、functionCall、inlineData 等保持原样
|
||
var mergedParts []any
|
||
var textBuffer strings.Builder
|
||
|
||
flushTextBuffer := func() {
|
||
if textBuffer.Len() > 0 {
|
||
mergedParts = append(mergedParts, map[string]any{
|
||
"text": textBuffer.String(),
|
||
})
|
||
textBuffer.Reset()
|
||
}
|
||
}
|
||
|
||
for _, part := range collectedParts {
|
||
// 检查是否是普通 text part
|
||
if text, ok := part["text"].(string); ok {
|
||
// 检查是否有 thought 标记
|
||
if thought, _ := part["thought"].(bool); thought {
|
||
// thinking part,先刷新 text buffer,然后保留原样
|
||
flushTextBuffer()
|
||
mergedParts = append(mergedParts, part)
|
||
} else {
|
||
// 普通 text,累积到 buffer
|
||
_, _ = textBuffer.WriteString(text)
|
||
}
|
||
} else {
|
||
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||
flushTextBuffer()
|
||
mergedParts = append(mergedParts, part)
|
||
}
|
||
}
|
||
|
||
// 刷新剩余的 text
|
||
flushTextBuffer()
|
||
|
||
setParts(mergedParts)
|
||
return result
|
||
}
|
||
|
||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||
if len(imageParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
result, existingParts, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 检查现有 parts 中是否已经有图片
|
||
for _, p := range existingParts {
|
||
if pm, ok := p.(map[string]any); ok {
|
||
if _, hasInline := pm["inlineData"]; hasInline {
|
||
return result // 已有图片,不重复添加
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加收集到的图片 parts
|
||
for _, imgPart := range imageParts {
|
||
existingParts = append(existingParts, imgPart)
|
||
}
|
||
setParts(existingParts)
|
||
return result
|
||
}
|
||
|
||
// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中
|
||
func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
|
||
if len(textParts) == 0 {
|
||
return response
|
||
}
|
||
|
||
mergedText := strings.Join(textParts, "")
|
||
result, existingParts, setParts := getOrCreateGeminiParts(response)
|
||
|
||
// 查找并更新第一个 text part,或创建新的
|
||
newParts := make([]any, 0, len(existingParts)+1)
|
||
textUpdated := false
|
||
|
||
for _, p := range existingParts {
|
||
pm, ok := p.(map[string]any)
|
||
if !ok {
|
||
newParts = append(newParts, p)
|
||
continue
|
||
}
|
||
if _, hasText := pm["text"]; hasText && !textUpdated {
|
||
// 用累积的文本替换
|
||
newPart := make(map[string]any)
|
||
for k, v := range pm {
|
||
newPart[k] = v
|
||
}
|
||
newPart["text"] = mergedText
|
||
newParts = append(newParts, newPart)
|
||
textUpdated = true
|
||
} else {
|
||
newParts = append(newParts, pm)
|
||
}
|
||
}
|
||
|
||
if !textUpdated {
|
||
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
|
||
}
|
||
|
||
setParts(newParts)
|
||
return result
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
|
||
c.JSON(status, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": message},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
// WriteMappedClaudeError 导出版本,供 handler 层使用(如 fallback 错误处理)
|
||
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
logBody, maxBytes := s.getLogConfig()
|
||
upstreamDetail := s.getUpstreamErrorDetail(body)
|
||
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: upstreamStatus,
|
||
UpstreamRequestID: upstreamRequestID,
|
||
Kind: "http_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
|
||
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
|
||
if logBody {
|
||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||
}
|
||
|
||
var statusCode int
|
||
var errType, errMsg string
|
||
|
||
switch upstreamStatus {
|
||
case 400:
|
||
statusCode = http.StatusBadRequest
|
||
errType = "invalid_request_error"
|
||
errMsg = getPassthroughOrDefault(upstreamMsg, "Invalid request")
|
||
case 401:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "authentication_error"
|
||
errMsg = "Upstream authentication failed"
|
||
case 403:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "permission_error"
|
||
errMsg = "Upstream access forbidden"
|
||
case 429:
|
||
statusCode = http.StatusTooManyRequests
|
||
errType = "rate_limit_error"
|
||
errMsg = "Upstream rate limit exceeded"
|
||
case 529:
|
||
statusCode = http.StatusServiceUnavailable
|
||
errType = "overloaded_error"
|
||
errMsg = "Upstream service overloaded"
|
||
default:
|
||
statusCode = http.StatusBadGateway
|
||
errType = "upstream_error"
|
||
errMsg = "Upstream request failed"
|
||
}
|
||
|
||
c.JSON(statusCode, gin.H{
|
||
"type": "error",
|
||
"error": gin.H{"type": errType, "message": errMsg},
|
||
})
|
||
if upstreamMsg == "" {
|
||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||
}
|
||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||
}
|
||
|
||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||
statusStr := "UNKNOWN"
|
||
switch status {
|
||
case 400:
|
||
statusStr = "INVALID_ARGUMENT"
|
||
case 404:
|
||
statusStr = "NOT_FOUND"
|
||
case 429:
|
||
statusStr = "RESOURCE_EXHAUSTED"
|
||
case 500:
|
||
statusStr = "INTERNAL"
|
||
case 502, 503:
|
||
statusStr = "UNAVAILABLE"
|
||
}
|
||
|
||
c.JSON(status, gin.H{
|
||
"error": gin.H{
|
||
"code": status,
|
||
"message": message,
|
||
"status": statusStr,
|
||
},
|
||
})
|
||
return fmt.Errorf("%s", message)
|
||
}
|
||
|
||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
var firstTokenMs *int
|
||
var last map[string]any
|
||
var lastWithParts map[string]any
|
||
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 流结束,转换并返回响应
|
||
goto returnResponse
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
|
||
}
|
||
return nil, ev.err
|
||
}
|
||
|
||
line := ev.line
|
||
trimmed := strings.TrimRight(line, "\r\n")
|
||
|
||
if !strings.HasPrefix(trimmed, "data:") {
|
||
continue
|
||
}
|
||
|
||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||
if payload == "" || payload == "[DONE]" {
|
||
continue
|
||
}
|
||
|
||
// 解包 v1internal 响应
|
||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||
if parseErr != nil {
|
||
continue
|
||
}
|
||
|
||
var parsed map[string]any
|
||
if err := json.Unmarshal(inner, &parsed); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 记录首 token 时间
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
last = parsed
|
||
|
||
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||
lastWithParts = parsed
|
||
|
||
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||
collectedParts = append(collectedParts, parts...)
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
|
||
return nil, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
|
||
returnResponse:
|
||
// 选择最后一个有效响应
|
||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||
|
||
// 处理空响应情况
|
||
if last == nil && lastWithParts == nil {
|
||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||
}
|
||
|
||
// 将收集的所有 parts 合并到最终响应中
|
||
if len(collectedParts) > 0 {
|
||
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||
}
|
||
|
||
// 序列化为 JSON(Gemini 格式)
|
||
geminiBody, err := json.Marshal(finalResponse)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
|
||
}
|
||
|
||
// 转换 Gemini 响应为 Claude 格式
|
||
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
|
||
if err != nil {
|
||
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||
}
|
||
|
||
c.Data(http.StatusOK, "application/json", claudeResp)
|
||
|
||
// 转换为 service.ClaudeUsage
|
||
usage := &ClaudeUsage{
|
||
InputTokens: agUsage.InputTokens,
|
||
OutputTokens: agUsage.OutputTokens,
|
||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||
}
|
||
|
||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
|
||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
c.Status(http.StatusOK)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||
var firstTokenMs *int
|
||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
maxLineSize := defaultMaxLineSize
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||
}
|
||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||
|
||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||
if agUsage == nil {
|
||
return &ClaudeUsage{}
|
||
}
|
||
return &ClaudeUsage{
|
||
InputTokens: agUsage.InputTokens,
|
||
OutputTokens: agUsage.OutputTokens,
|
||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||
}
|
||
}
|
||
|
||
type scanEvent struct {
|
||
line string
|
||
err error
|
||
}
|
||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||
events := make(chan scanEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev scanEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt int64
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
go func() {
|
||
defer close(events)
|
||
for scanner.Scan() {
|
||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||
return
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
_ = sendEvent(scanEvent{err: err})
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
streamInterval := time.Duration(0)
|
||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||
errorEventSent := false
|
||
sendErrorEvent := func(reason string) {
|
||
if errorEventSent {
|
||
return
|
||
}
|
||
errorEventSent = true
|
||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||
flusher.Flush()
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
// 发送结束事件
|
||
finalEvents, agUsage := processor.Finish()
|
||
if len(finalEvents) > 0 {
|
||
_, _ = c.Writer.Write(finalEvents)
|
||
flusher.Flush()
|
||
}
|
||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||
}
|
||
if ev.err != nil {
|
||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||
sendErrorEvent("response_too_large")
|
||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||
}
|
||
sendErrorEvent("stream_read_error")
|
||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||
}
|
||
|
||
line := ev.line
|
||
// 处理 SSE 行,转换为 Claude 格式
|
||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
||
|
||
if len(claudeEvents) > 0 {
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
|
||
finalEvents, agUsage := processor.Finish()
|
||
if len(finalEvents) > 0 {
|
||
_, _ = c.Writer.Write(finalEvents)
|
||
}
|
||
sendErrorEvent("write_failed")
|
||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
log.Printf("Stream data interval timeout (antigravity)")
|
||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||
sendErrorEvent("stream_timeout")
|
||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
||
var req antigravity.GeminiRequest
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
return "2K" // 默认 2K
|
||
}
|
||
|
||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||
if size == "1K" || size == "2K" || size == "4K" {
|
||
return size
|
||
}
|
||
}
|
||
|
||
return "2K" // 默认 2K
|
||
}
|
||
|
||
// isImageGenerationModel reports whether model names an image-generation model.
//
// Matching is case-insensitive and ignores a leading "models/" prefix. A model
// matches when it equals "gemini-3-pro-image" or "gemini-2.5-flash-image", or
// starts with one of those names followed by "-" (for example
// "gemini-3-pro-image-preview").
func isImageGenerationModel(model string) bool {
	name := strings.TrimPrefix(strings.ToLower(model), "models/")

	for _, base := range [...]string{"gemini-3-pro-image", "gemini-2.5-flash-image"} {
		if name == base || strings.HasPrefix(name, base+"-") {
			return true
		}
	}
	return false
}
|
||
|
||
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
|
||
func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||
var payload map[string]any
|
||
if err := json.Unmarshal(body, &payload); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
modified := false
|
||
|
||
// 1. 清理 Tools
|
||
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
|
||
for _, t := range tools {
|
||
toolMap, ok := t.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// function_declarations (snake_case) or functionDeclarations (camelCase)
|
||
var funcs []any
|
||
if f, ok := toolMap["functionDeclarations"].([]any); ok {
|
||
funcs = f
|
||
} else if f, ok := toolMap["function_declarations"].([]any); ok {
|
||
funcs = f
|
||
}
|
||
|
||
if len(funcs) == 0 {
|
||
continue
|
||
}
|
||
|
||
for _, f := range funcs {
|
||
funcMap, ok := f.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
if params, ok := funcMap["parameters"].(map[string]any); ok {
|
||
antigravity.DeepCleanUndefined(params)
|
||
cleaned := antigravity.CleanJSONSchema(params)
|
||
funcMap["parameters"] = cleaned
|
||
modified = true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !modified {
|
||
return body, nil
|
||
}
|
||
|
||
return json.Marshal(payload)
|
||
}
|
||
|
||
// filterEmptyPartsFromGeminiRequest removes every entry of "contents" whose
// "parts" is an empty array, because the Gemini API rejects empty parts.
//
// Entries that are not objects, have no "parts" key, or whose "parts" is not
// an array are kept. When nothing is removed (or "contents" is missing or
// empty) the original bytes are returned unchanged; otherwise the payload is
// re-marshalled. Invalid JSON returns the decode error.
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, err
	}

	contents, _ := payload["contents"].([]any)
	if len(contents) == 0 {
		return body, nil
	}

	// hasEmptyParts is true only for an object whose "parts" is an array of
	// length zero.
	hasEmptyParts := func(entry any) bool {
		message, isMap := entry.(map[string]any)
		if !isMap {
			return false
		}
		parts, isList := message["parts"].([]any)
		return isList && len(parts) == 0
	}

	kept := make([]any, 0, len(contents))
	for _, entry := range contents {
		if !hasEmptyParts(entry) {
			kept = append(kept, entry)
		}
	}

	// Nothing dropped: avoid re-encoding so the body stays byte-identical.
	if len(kept) == len(contents) {
		return body, nil
	}

	payload["contents"] = kept
	return json.Marshal(payload)
}
|
||
|
||
// ---------------------------------------------------------------------------
// Upstream-only forwarding methods.
// Upstream accounts connect directly to an Anthropic/Gemini-compatible upstream
// endpoint and do not go through the Antigravity OAuth protocol conversion.
// ---------------------------------------------------------------------------
|
||
|
||
// testUpstreamConnection 测试 upstream 账号连接
|
||
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
||
if baseURL == "" {
|
||
return nil, errors.New("upstream account missing base_url in credentials")
|
||
}
|
||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||
if apiKey == "" {
|
||
return nil, errors.New("upstream account missing api_key in credentials")
|
||
}
|
||
|
||
mappedModel := s.getMappedModel(account, modelID)
|
||
if mappedModel == "" {
|
||
return nil, fmt.Errorf("model %s not in whitelist", modelID)
|
||
}
|
||
|
||
// 构建最小 Claude 格式请求
|
||
requestBody, _ := json.Marshal(map[string]any{
|
||
"model": mappedModel,
|
||
"max_tokens": 1,
|
||
"messages": []map[string]any{
|
||
{"role": "user", "content": "."},
|
||
},
|
||
"stream": false,
|
||
})
|
||
|
||
apiURL := baseURL + "/antigravity/v1/messages"
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("x-api-key", apiKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL)
|
||
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode >= 400 {
|
||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
// 从 Claude 格式非流式响应中提取文本
|
||
var claudeResp struct {
|
||
Content []struct {
|
||
Text string `json:"text"`
|
||
} `json:"content"`
|
||
}
|
||
text := ""
|
||
if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 {
|
||
text = claudeResp.Content[0].Text
|
||
}
|
||
|
||
return &TestConnectionResult{
|
||
Text: text,
|
||
MappedModel: mappedModel,
|
||
}, nil
|
||
}
|
||
|
||
// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换)
|
||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
||
if baseURL == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url")
|
||
}
|
||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||
if apiKey == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key")
|
||
}
|
||
|
||
// 解析请求以获取模型和流式标志
|
||
var claudeReq antigravity.ClaudeRequest
|
||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body")
|
||
}
|
||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model")
|
||
}
|
||
|
||
originalModel := claudeReq.Model
|
||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||
if mappedModel == "" {
|
||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||
}
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 统计模型调用次数
|
||
if s.cache != nil {
|
||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
||
}
|
||
|
||
apiURL := baseURL + "/antigravity/v1/messages"
|
||
log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel)
|
||
|
||
// 构建请求:body 原样透传
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||
}
|
||
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
||
if c != nil && c.Request != nil {
|
||
for key, values := range c.Request.Header {
|
||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||
continue
|
||
}
|
||
for _, v := range values {
|
||
req.Header.Add(key, v)
|
||
}
|
||
}
|
||
}
|
||
// 覆盖认证头
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("x-api-key", apiKey)
|
||
|
||
if c != nil && len(body) > 0 {
|
||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||
}
|
||
|
||
// 单次发送,不重试
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err))
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 错误响应处理
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||
}
|
||
|
||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||
}
|
||
|
||
// 成功响应:透传 response header + body
|
||
requestID := resp.Header.Get("x-request-id")
|
||
|
||
// 透传上游响应头(排除 hop-by-hop)
|
||
for key, values := range resp.Header {
|
||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||
continue
|
||
}
|
||
for _, v := range values {
|
||
c.Header(key, v)
|
||
}
|
||
}
|
||
|
||
c.Status(resp.StatusCode)
|
||
_, copyErr := io.Copy(c.Writer, resp.Body)
|
||
if copyErr != nil {
|
||
log.Printf("%s status=copy_error error=%v", prefix, copyErr)
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Model: originalModel,
|
||
Stream: claudeReq.Stream,
|
||
Duration: time.Since(startTime),
|
||
}, nil
|
||
}
|
||
|
||
// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换)
|
||
func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||
startTime := time.Now()
|
||
sessionID := getSessionID(c)
|
||
prefix := logPrefix(sessionID, account.Name)
|
||
|
||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
||
if baseURL == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url")
|
||
}
|
||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||
if apiKey == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key")
|
||
}
|
||
|
||
if strings.TrimSpace(originalModel) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||
}
|
||
if strings.TrimSpace(action) == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||
}
|
||
if len(body) == 0 {
|
||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||
}
|
||
|
||
imageSize := s.extractImageSize(body)
|
||
|
||
switch action {
|
||
case "generateContent", "streamGenerateContent":
|
||
// ok
|
||
case "countTokens":
|
||
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
|
||
return &ForwardResult{
|
||
RequestID: "",
|
||
Usage: ClaudeUsage{},
|
||
Model: originalModel,
|
||
Stream: false,
|
||
Duration: time.Since(time.Now()),
|
||
FirstTokenMs: nil,
|
||
}, nil
|
||
default:
|
||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||
}
|
||
|
||
mappedModel := s.getMappedModel(account, originalModel)
|
||
if mappedModel == "" {
|
||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||
}
|
||
|
||
// 代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 统计模型调用次数
|
||
if s.cache != nil {
|
||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
||
}
|
||
|
||
// 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION
|
||
apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
||
if stream || action == "streamGenerateContent" {
|
||
apiURL += "?alt=sse"
|
||
}
|
||
|
||
log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action)
|
||
|
||
// 构建请求:body 原样透传
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request")
|
||
}
|
||
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
||
if c != nil && c.Request != nil {
|
||
for key, values := range c.Request.Header {
|
||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||
continue
|
||
}
|
||
for _, v := range values {
|
||
req.Header.Add(key, v)
|
||
}
|
||
}
|
||
}
|
||
// 覆盖认证头
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
if c != nil && len(body) > 0 {
|
||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||
}
|
||
|
||
// 单次发送,不重试
|
||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err))
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 错误响应处理
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
contentType := resp.Header.Get("Content-Type")
|
||
|
||
requestID := resp.Header.Get("x-request-id")
|
||
if requestID != "" {
|
||
c.Header("x-request-id", requestID)
|
||
}
|
||
|
||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||
|
||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "failover",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||
}
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||
Platform: account.Platform,
|
||
AccountID: account.ID,
|
||
AccountName: account.Name,
|
||
UpstreamStatusCode: resp.StatusCode,
|
||
UpstreamRequestID: requestID,
|
||
Kind: "http_error",
|
||
Message: upstreamMsg,
|
||
Detail: upstreamDetail,
|
||
})
|
||
log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500))
|
||
c.Data(resp.StatusCode, contentType, respBody)
|
||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||
}
|
||
|
||
// 成功响应:透传 response header + body
|
||
requestID := resp.Header.Get("x-request-id")
|
||
|
||
// 透传上游响应头(排除 hop-by-hop)
|
||
for key, values := range resp.Header {
|
||
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||
continue
|
||
}
|
||
for _, v := range values {
|
||
c.Header(key, v)
|
||
}
|
||
}
|
||
|
||
c.Status(resp.StatusCode)
|
||
_, copyErr := io.Copy(c.Writer, resp.Body)
|
||
if copyErr != nil {
|
||
log.Printf("%s status=copy_error error=%v", prefix, copyErr)
|
||
}
|
||
|
||
imageCount := 0
|
||
if isImageGenerationModel(mappedModel) {
|
||
imageCount = 1
|
||
}
|
||
|
||
return &ForwardResult{
|
||
RequestID: requestID,
|
||
Model: originalModel,
|
||
Stream: stream,
|
||
Duration: time.Since(startTime),
|
||
ImageCount: imageCount,
|
||
ImageSize: imageSize,
|
||
}, nil
|
||
}
|