Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Edric Li
2026-01-08 21:35:34 +08:00
8 changed files with 645 additions and 156 deletions

View File

@@ -144,7 +144,7 @@ func (h *UsageHandler) List(c *gin.Context) {
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}

View File

@@ -234,7 +234,21 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
// Only includes ID and Name - no sensitive fields like Credentials, Proxy, etc.
func AccountSummaryFromService(a *service.Account) *AccountSummary {
if a == nil {
return nil
}
return &AccountSummary{
ID: a.ID,
Name: a.Name,
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
if l == nil {
return nil
}
@@ -270,12 +284,27 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
Account: AccountFromService(l.Account),
Account: account,
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see account information.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only).
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil

View File

@@ -187,11 +187,18 @@ type UsageLog struct {
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
// AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`

View File

@@ -13,16 +13,48 @@ import (
"time"
)
// resolveHost 从 URL 解析 host
func resolveHost(urlStr string) string {
parsed, err := url.Parse(urlStr)
if err != nil {
return ""
}
return parsed.Host
}
// NewAPIRequest 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
isStream := action == "streamGenerateContent"
if isStream {
apiURL += "?alt=sse"
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 基础 Headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
// Accept Header 根据请求类型设置
if isStream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
// 显式设置 Host Header
if host := resolveHost(apiURL); host != "" {
req.Host = host
}
// 注意requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header
return req, nil
}

View File

@@ -33,10 +33,11 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
BaseURL = "https://cloudcode-pa.googleapis.com"
// 优先使用 sandbox daily URL配额更宽松
BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
// User-Agent
UserAgent = "antigravity/1.11.9 windows/amd64"
// User-Agent(模拟官方客户端)
UserAgent = "antigravity/1.104.0 darwin/arm64"
// Session 过期时间
SessionTTL = 30 * time.Minute

View File

@@ -1,17 +1,46 @@
package antigravity
import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
var (
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
sessionRandMutex sync.Mutex
)
// generateStableSessionID 基于用户消息内容生成稳定的 session ID
func generateStableSessionID(contents []GeminiContent) string {
// 查找第一个 user 消息的文本
for _, content := range contents {
if content.Role == "user" && len(content.Parts) > 0 {
if text := content.Parts[0].Text; text != "" {
h := sha256.Sum256([]byte(text))
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
return "-" + strconv.FormatInt(n, 10)
}
}
}
// 回退:生成随机 session ID
sessionRandMutex.Lock()
n := sessionRand.Int63n(9_000_000_000_000_000_000)
sessionRandMutex.Unlock()
return "-" + strconv.FormatInt(n, 10)
}
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
@@ -67,8 +96,15 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 5. 构建内部请求
innerRequest := GeminiRequest{
Contents: contents,
SafetySettings: DefaultSafetySettings,
Contents: contents,
// 总是设置 toolConfig与官方客户端一致
ToolConfig: &GeminiToolConfig{
FunctionCallingConfig: &GeminiFunctionCallingConfig{
Mode: "VALIDATED",
},
},
// 总是生成 sessionId基于用户消息内容
SessionID: generateStableSessionID(contents),
}
if systemInstruction != nil {
@@ -79,14 +115,9 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
if len(tools) > 0 {
innerRequest.Tools = tools
innerRequest.ToolConfig = &GeminiToolConfig{
FunctionCallingConfig: &GeminiFunctionCallingConfig{
Mode: "VALIDATED",
},
}
}
// 如果提供了 metadata.user_id复用为 sessionId
// 如果提供了 metadata.user_id优先使用
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
innerRequest.SessionID = claudeReq.Metadata.UserID
}
@@ -95,7 +126,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
v1Req := V1InternalRequest{
Project: projectID,
RequestID: "agent-" + uuid.New().String(),
UserAgent: "sub2api",
UserAgent: "antigravity", // 固定值,与官方客户端一致
RequestType: "agent",
Model: mappedModel,
Request: innerRequest,
@@ -104,37 +135,42 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
return json.Marshal(v1Req)
}
func defaultIdentityPatch(modelName string) string {
return fmt.Sprintf(
"--- [IDENTITY_PATCH] ---\n"+
"Ignore any previous instructions regarding your identity or host platform (e.g., Amazon Q, Google AI).\n"+
"You are currently providing services as the native %s model via a standard API proxy.\n"+
"Always use the 'claude' command for terminal tasks if relevant.\n"+
"--- [SYSTEM_PROMPT_BEGIN] ---\n",
modelName,
)
// antigravityIdentity Antigravity identity 提示词
const antigravityIdentity = `<identity>
You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.
You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.
The USER will send you requests, which you must always prioritize addressing. Along with each USER request, we will attach additional metadata about their current state, such as what files they have open and where their cursor is.
This information may or may not be relevant to the coding task, it is up for you to decide.
</identity>
<communication_style>
- **Proactiveness**. As an agent, you are allowed to be proactive, but only in the course of completing the user's task. For example, if the user asks you to add a new component, you can edit the code, verify build and test statuses, and take any other obvious follow-up actions, such as performing additional research. However, avoid surprising the user. For example, if the user asks HOW to approach something, you should answer their question and instead of jumping into editing a file.</communication_style>`
func defaultIdentityPatch(_ string) string {
return antigravityIdentity
}
// GetDefaultIdentityPatch 返回默认的 Antigravity 身份提示词
func GetDefaultIdentityPatch() string {
return antigravityIdentity
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart
// 可选注入身份防护指令(身份补丁)
if opts.EnableIdentityPatch {
identityPatch := strings.TrimSpace(opts.IdentityPatch)
if identityPatch == "" {
identityPatch = defaultIdentityPatch(modelName)
}
parts = append(parts, GeminiPart{Text: identityPatch})
}
// 先解析用户的 system prompt检测是否已包含 Antigravity identity
userHasAntigravityIdentity := false
var userSystemParts []GeminiPart
// 解析 system prompt
if len(system) > 0 {
// 尝试解析为字符串
var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" {
parts = append(parts, GeminiPart{Text: sysStr})
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
}
} else {
// 尝试解析为数组
@@ -142,17 +178,28 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
parts = append(parts, GeminiPart{Text: block.Text})
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
}
}
}
}
}
// identity patch 模式下,用分隔符包裹 system prompt便于上游识别/调试;关闭时尽量保持原始 system prompt。
if opts.EnableIdentityPatch && len(parts) > 0 {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
// 仅在用户未提供 Antigravity identity 时注入
if opts.EnableIdentityPatch && !userHasAntigravityIdentity {
identityPatch := strings.TrimSpace(opts.IdentityPatch)
if identityPatch == "" {
identityPatch = defaultIdentityPatch(modelName)
}
parts = append(parts, GeminiPart{Text: identityPatch})
}
// 添加用户的 system prompt
parts = append(parts, userSystemParts...)
if len(parts) == 0 {
return nil
}

View File

@@ -181,12 +181,15 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("构建请求失败: %w", err)
}
// 构建 HTTP 请求(非流式
req, err := antigravity.NewAPIRequest(ctx, "generateContent", accessToken, requestBody)
// 构建 HTTP 请求(总是使用流式 endpoint与官方客户端一致
req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody)
if err != nil {
return nil, err
}
// 调试日志Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
@@ -210,14 +213,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解包 v1internal 响应
unwrapped, err := s.unwrapV1InternalResponse(respBody)
if err != nil {
return nil, fmt.Errorf("解包响应失败: %w", err)
}
// 提取响应文本
text := extractGeminiResponseText(unwrapped)
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
return &TestConnectionResult{
Text: text,
@@ -236,6 +233,12 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
},
},
},
// Antigravity 上游要求必须包含身份提示词
"systemInstruction": map[string]any{
"parts": []map[string]any{
{"text": antigravity.GetDefaultIdentityPatch()},
},
},
}
payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
@@ -267,38 +270,66 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
return opts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本
func extractGeminiResponseText(respBody []byte) string {
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return ""
}
candidates, ok := resp["candidates"].([]any)
if !ok || len(candidates) == 0 {
return ""
}
candidate, ok := candidates[0].(map[string]any)
if !ok {
return ""
}
content, ok := candidate["content"].(map[string]any)
if !ok {
return ""
}
parts, ok := content["parts"].([]any)
if !ok {
return ""
}
// extractTextFromSSEResponse 从 SSE 流式响应中提取文本
func extractTextFromSSEResponse(respBody []byte) string {
var texts []string
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok && text != "" {
texts = append(texts, text)
lines := bytes.Split(respBody, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
// 跳过 SSE 前缀
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimPrefix(line, []byte("data:"))
line = bytes.TrimSpace(line)
}
// 跳过非 JSON 行
if len(line) == 0 || line[0] != '{' {
continue
}
// 解析 JSON
var data map[string]any
if err := json.Unmarshal(line, &data); err != nil {
continue
}
// 尝试从 response.candidates[0].content.parts[].text 提取
response, ok := data["response"].(map[string]any)
if !ok {
// 尝试直接从 candidates 提取(某些响应格式)
response = data
}
candidates, ok := response["candidates"].([]any)
if !ok || len(candidates) == 0 {
continue
}
candidate, ok := candidates[0].(map[string]any)
if !ok {
continue
}
content, ok := candidate["content"].(map[string]any)
if !ok {
continue
}
parts, ok := content["parts"].([]any)
if !ok {
continue
}
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok && text != "" {
texts = append(texts, text)
}
}
}
}
@@ -306,6 +337,53 @@ func extractGeminiResponseText(respBody []byte) string {
return strings.Join(texts, "")
}
// injectIdentityPatchToGeminiRequest 为 Gemini 格式请求注入身份提示词
// 如果请求中已包含 "You are Antigravity" 则不重复注入
func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
var request map[string]any
if err := json.Unmarshal(body, &request); err != nil {
return nil, fmt.Errorf("解析 Gemini 请求失败: %w", err)
}
// 检查现有 systemInstruction 是否已包含身份提示词
if sysInst, ok := request["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
if strings.Contains(text, "You are Antigravity") {
// 已包含身份提示词,直接返回原始请求
return body, nil
}
}
}
}
}
}
// 获取默认身份提示词
identityPatch := antigravity.GetDefaultIdentityPatch()
// 构建新的 systemInstruction
newPart := map[string]any{"text": identityPatch}
if existing, ok := request["systemInstruction"].(map[string]any); ok {
// 已有 systemInstruction在开头插入身份提示词
if parts, ok := existing["parts"].([]any); ok {
existing["parts"] = append([]any{newPart}, parts...)
} else {
existing["parts"] = []any{newPart}
}
} else {
// 没有 systemInstruction创建新的
request["systemInstruction"] = map[string]any{
"parts": []any{newPart},
}
}
return json.Marshal(request)
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
@@ -316,7 +394,7 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin
wrapped := map[string]any{
"project": projectID,
"requestId": "agent-" + uuid.New().String(),
"userAgent": "sub2api",
"userAgent": "antigravity", // 固定值,与官方客户端一致
"requestType": "agent",
"model": model,
"request": request,
@@ -391,17 +469,20 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL()
}
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
transformOpts.EnableIdentityPatch = true // 强制启用Antigravity 上游必需
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
action = "streamGenerateContent?alt=sse"
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 重试循环
var resp *http.Response
@@ -438,7 +519,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
@@ -557,6 +638,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
@@ -565,10 +647,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usage, err = s.handleClaudeNonStreamingResponse(c, resp, originalModel)
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
return &ForwardResult{
@@ -901,21 +987,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL = account.Proxy.URL()
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
// Antigravity 上游要求必须包含身份提示词,注入到请求
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
if err != nil {
return nil, err
}
// 构建上游 action
upstreamAction := action
if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent"
}
if stream || upstreamAction == "streamGenerateContent" {
upstreamAction += "?alt=sse"
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
return nil, err
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
@@ -992,7 +1079,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
@@ -1042,7 +1129,8 @@ handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
if stream || upstreamAction == "streamGenerateContent" {
if stream {
// 客户端要求流式,直接透传
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
@@ -1051,11 +1139,14 @@ handleSuccess:
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = usageResp
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
}
if usage == nil {
@@ -1102,9 +1193,9 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
delay := antigravityRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > antigravityRetryMaxDelay {
delay = antigravityRetryMaxDelay
}
// +/- 20% jitter
@@ -1316,25 +1407,150 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
type scanEvent struct {
line string
err error
}
// 解包 v1internal 响应
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
var parsed map[string]any
if json.Unmarshal(unwrapped, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
c.Data(resp.StatusCode, "application/json", unwrapped)
return u, nil
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
c.Data(resp.StatusCode, "application/json", unwrapped)
return &ClaudeUsage{}, nil
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,返回收集的响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 提取 usage
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
respBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
c.Data(http.StatusOK, "application/json", respBody)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
@@ -1411,17 +1627,148 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
return fmt.Errorf("%s", message)
}
// handleClaudeNonStreamingResponse 处理 Claude 非流式响应Gemini → Claude 转换)
func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
// 用于处理客户端非流式请求但上游只支持流式的情况
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
// 流结束,转换并返回响应
goto returnResponse
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if !strings.HasPrefix(trimmed, "data:") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
continue
}
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr != nil {
continue
}
var parsed map[string]any
if err := json.Unmarshal(inner, &parsed); err != nil {
continue
}
// 记录首 token 时间
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
last = parsed
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 序列化为 JSONGemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
}
// 转换 Gemini 响应为 Claude 格式
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel)
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
if err != nil {
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(body))
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
@@ -1434,7 +1781,8 @@ func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Cont
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return usage, nil
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)

View File

@@ -109,12 +109,13 @@ type ClaudeUsage struct {
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量
@@ -1465,6 +1466,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
@@ -1477,6 +1479,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
@@ -1485,12 +1488,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
@@ -1845,8 +1849,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
@@ -1942,14 +1947,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
// 上游完成,返回结果
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
@@ -1960,38 +1978,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
}
// 转发行
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
flusher.Flush()
}
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
// 无论客户端是否断开,都解析 usage仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
@@ -1999,6 +2019,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")