From e1a68497d63775222772cf435a08fef057d9bfdb Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 17:06:49 +0800 Subject: [PATCH 01/41] =?UTF-8?q?refactor:=20simplify=20sticky=20session?= =?UTF-8?q?=20rate=20limit=20handling=20=E2=80=94=20switch=20immediately?= =?UTF-8?q?=20on=20any=20rate=20limit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove threshold-based waiting in both sticky session and antigravity pre-check paths. When a model is rate-limited, immediately clear the sticky session and switch accounts instead of waiting for short durations. --- .../service/antigravity_gateway_service.go | 26 +++++-------------- .../service/antigravity_rate_limit_test.go | 18 ++++++------- backend/internal/service/gateway_service.go | 14 +++------- .../internal/service/sticky_session_test.go | 13 +++++----- 4 files changed, 26 insertions(+), 45 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 7fdb4d19..22065e61 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -264,27 +264,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // antigravityRetryLoop 执行带 URL fallback 的重试循环 func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { - // 预检查:如果账号已限流,根据剩余时间决定等待或切换 + // 预检查:如果账号已限流,直接返回切换信号 if p.requestedModel != "" { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - // 限流剩余时间较短,等待后继续 - log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) - select { - case <-p.ctx.Done(): - return nil, p.ctx.Err() - case <-time.After(remaining): - } - } else 
{ - // 限流剩余时间较长,返回账号切换信号 - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID) - return nil, &AntigravityAccountSwitchError{ - OriginalAccountID: p.account.ID, - RateLimitedModel: p.requestedModel, - IsStickySession: p.isStickySession, - } + log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + return nil, &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: p.requestedModel, + IsStickySession: p.isStickySession, } } } diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 20936356..cd2a7a4a 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -803,7 +803,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) { require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope") } -func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 1, @@ -815,19 +815,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - // RFC3339 here is second-precision; keep it safely in the future. 
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339), }, }, }, } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) - defer cancel() - svc := &AntigravityGatewayService{} result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, + ctx: context.Background(), prefix: "[test]", account: account, accessToken: "token", @@ -841,12 +837,16 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi }, }) - require.ErrorIs(t, err, context.DeadlineExceeded) require.Nil(t, result) - require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check") + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession) + require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check") } -func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 2, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 32646b11..6ba15399 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -323,21 +323,15 @@ func derefGroupID(groupID *int64) int64 { return *groupID } -// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。 -// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。 -// 低于此阈值时保持粘性会话,等待短暂限流结束。 -const stickySessionRateLimitThreshold = 10 * time.Second - // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, -// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 +// 或请求的模型处于限流状态时,返回 true。 // 这确保后续请求不会继续使用不可用的账号。 // // shouldClearStickySession checks if an 
account is in an unschedulable state // and the sticky session binding should be cleared. // Returns true when account status is error/disabled, schedulable is false, -// within temporary unschedulable period, or model rate limit remaining time -// exceeds stickySessionRateLimitThreshold. +// within temporary unschedulable period, or the requested model is rate-limited. // This ensures subsequent requests won't continue using unavailable accounts. func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { @@ -349,8 +343,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { return true } - // 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 - if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { + // 检查模型限流和 scope 限流,有限流即清除粘性会话 + if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { return true } return false diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index c70f12fe..e7ef8982 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -23,8 +23,7 @@ import ( // - 临时不可调度且未过期:清理 // - 临时不可调度已过期:不清理 // - 正常可调度状态:不清理 -// - 模型限流超过阈值:清理 -// - 模型限流未超过阈值:不清理 +// - 模型限流(任意时长):清理 // // TestShouldClearStickySession tests the sticky session clearing logic. 
// Verifies correct behavior for various account states including: @@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) { future := now.Add(1 * time.Hour) past := now.Add(-1 * time.Hour) - // 短限流时间(低于阈值,不应清除粘性会话) + // 短限流时间(有限流即清除粘性会话) shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339) - // 长限流时间(超过阈值,应清除粘性会话) + // 长限流时间(有限流即清除粘性会话) longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339) tests := []struct { @@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) { {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true}, {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false}, {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false}, - // 模型限流测试 + // 模型限流测试:有限流即清除 { name: "model rate limited short duration", account: &Account{ @@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) { }, }, requestedModel: "claude-sonnet-4", - want: false, // 低于阈值,不清除 + want: true, // 有限流即清除 }, { name: "model rate limited long duration", @@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) { }, }, requestedModel: "claude-sonnet-4", - want: true, // 超过阈值,清除 + want: true, // 有限流即清除 }, { name: "model rate limited different model", From 50a783ff0153b6c8b680ff3132e69bf1ebd65045 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 17:35:05 +0800 Subject: [PATCH 02/41] feat: add Anthropic sticky session digest chain matching via Trie The previous fallback (step 3) in GenerateSessionHash hashed system + all messages together, producing a different hash each round as the conversation grew ([a] -> [a,b] -> [a,b,c]). This made fallback sticky sessions ineffective for multi-turn conversations. 
Implement per-message Trie digest chain matching (reusing Gemini's Trie infrastructure) so that the previous round's chain is always a prefix of the current round's chain, enabling reliable session affinity. --- backend/internal/handler/gateway_handler.go | 62 +++ backend/internal/repository/gateway_cache.go | 38 ++ backend/internal/service/anthropic_session.go | 89 +++++ .../service/anthropic_session_test.go | 357 ++++++++++++++++++ .../service/gateway_multiplatform_test.go | 8 + backend/internal/service/gateway_service.go | 56 +-- .../service/gemini_multiplatform_test.go | 8 + .../service/openai_gateway_service_test.go | 8 + 8 files changed, 604 insertions(+), 22 deletions(-) create mode 100644 backend/internal/service/anthropic_session.go create mode 100644 backend/internal/service/anthropic_session_test.go diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7e6b2f03..a505f578 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "github.com/google/uuid" ) // GatewayHandler handles API gateway requests @@ -212,6 +213,53 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) } + + // === Anthropic 内容摘要会话 Fallback 逻辑 === + // 当原有会话标识无效时(sessionBoundAccountID == 0),尝试基于内容摘要链匹配 + var anthropicDigestChain string + var anthropicPrefixHash string + var anthropicSessionUUID string + useAnthropicDigestFallback := sessionBoundAccountID == 0 && platform != service.PlatformGemini + + if useAnthropicDigestFallback { + anthropicDigestChain = service.BuildAnthropicDigestChain(parsedReq) + if anthropicDigestChain != "" { + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + anthropicPrefixHash = 
service.GenerateGeminiPrefixHash( + subject.UserID, + apiKey.ID, + clientIP, + userAgent, + platform, + reqModel, + ) + + foundUUID, foundAccountID, found := h.gatewayService.FindAnthropicSession( + c.Request.Context(), + derefGroupID(apiKey.GroupID), + anthropicPrefixHash, + anthropicDigestChain, + ) + if found { + sessionBoundAccountID = foundAccountID + anthropicSessionUUID = foundUUID + log.Printf("[Anthropic] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", + foundUUID[:8], foundAccountID, truncateDigestChain(anthropicDigestChain)) + + if sessionKey == "" { + sessionKey = service.GenerateAnthropicDigestSessionKey(anthropicPrefixHash, foundUUID) + } + _ = h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, foundAccountID) + } else { + anthropicSessionUUID = uuid.New().String() + if sessionKey == "" { + sessionKey = service.GenerateAnthropicDigestSessionKey(anthropicPrefixHash, anthropicSessionUUID) + } + } + } + } + // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 @@ -540,6 +588,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + // 保存 Anthropic 内容摘要会话(用于 Fallback 匹配) + if useAnthropicDigestFallback && anthropicDigestChain != "" && anthropicPrefixHash != "" { + if err := h.gatewayService.SaveAnthropicSession( + c.Request.Context(), + derefGroupID(apiKey.GroupID), + anthropicPrefixHash, + anthropicDigestChain, + anthropicSessionUUID, + account.ID, + ); err != nil { + log.Printf("[Anthropic] Failed to save digest session: %v", err) + } + } + // 异步记录使用量(subscription已在函数开头获取) go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 9365252a..46ae0c16 100644 --- 
a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -238,3 +238,41 @@ func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, pre return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() } + +// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============ + +// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本) +func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + if digestChain == "" { + return "", 0, false + } + + trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) + ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) + + result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() + if err != nil || result == nil { + return "", 0, false + } + + value, ok := result.(string) + if !ok || value == "" { + return "", 0, false + } + + uuid, accountID, ok = service.ParseGeminiSessionValue(value) + return uuid, accountID, ok +} + +// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本) +func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + if digestChain == "" { + return nil + } + + trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash) + value := service.FormatGeminiSessionValue(uuid, accountID) + ttlSeconds := int(service.AnthropicSessionTTL().Seconds()) + + return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() +} diff --git a/backend/internal/service/anthropic_session.go b/backend/internal/service/anthropic_session.go new file mode 100644 index 00000000..2d86ed35 --- /dev/null +++ b/backend/internal/service/anthropic_session.go @@ -0,0 +1,89 @@ +package service + +import ( + "encoding/json" + "strconv" + "strings" + "time" +) + 
+// Anthropic 会话 Fallback 相关常量 +const ( + // anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟) + anthropicSessionTTLSeconds = 300 + + // anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀 + anthropicTrieKeyPrefix = "anthropic:trie:" + + // anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀 + anthropicDigestSessionKeyPrefix = "anthropic:digest:" +) + +// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL +func AnthropicSessionTTL() time.Duration { + return anthropicSessionTTLSeconds * time.Second +} + +// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链 +// 格式: s:-u:-a:-u:-... +// s = system, u = user, a = assistant +func BuildAnthropicDigestChain(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var parts []string + + // 1. system prompt + if parsed.System != nil { + systemData, _ := json.Marshal(parsed.System) + if len(systemData) > 0 && string(systemData) != "null" { + parts = append(parts, "s:"+shortHash(systemData)) + } + } + + // 2. messages + for _, msg := range parsed.Messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + role, _ := msgMap["role"].(string) + prefix := rolePrefix(role) + content := msgMap["content"] + contentData, _ := json.Marshal(content) + parts = append(parts, prefix+":"+shortHash(contentData)) + } + + return strings.Join(parts, "-") +} + +// rolePrefix 将 Anthropic 的 role 映射为单字符前缀 +func rolePrefix(role string) string { + switch role { + case "assistant": + return "a" + default: + return "u" + } +} + +// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key +// 格式: anthropic:trie:{groupID}:{prefixHash} +func BuildAnthropicTrieKey(groupID int64, prefixHash string) string { + return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash +} + +// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey +// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey +func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string { + prefix := prefixHash + if 
len(prefixHash) >= 8 { + prefix = prefixHash[:8] + } + uuidPart := uuid + if len(uuid) >= 8 { + uuidPart = uuid[:8] + } + return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart +} diff --git a/backend/internal/service/anthropic_session_test.go b/backend/internal/service/anthropic_session_test.go new file mode 100644 index 00000000..e2f873e7 --- /dev/null +++ b/backend/internal/service/anthropic_session_test.go @@ -0,0 +1,357 @@ +package service + +import ( + "strings" + "testing" +) + +func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) { + result := BuildAnthropicDigestChain(nil) + if result != "" { + t.Errorf("expected empty string for nil request, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{}, + } + result := BuildAnthropicDigestChain(parsed) + if result != "" { + t.Errorf("expected empty string for empty messages, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "a:") { + t.Errorf("part[1] 
expected prefix 'a:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) { + parsed := &ParsedRequest{ + System: "You are a helpful assistant", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "u:") { + t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) { + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant"}, + }, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) { + // 核心测试:验证对话增长时链的前缀关系 + // 上一轮的完整链一定是下一轮链的前缀 + system := "You are a helpful assistant" + + // 第 1 轮: system + user + round1 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + chain1 := BuildAnthropicDigestChain(round1) + + // 第 2 轮: system + user + assistant + user + round2 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + }, + } + chain2 := BuildAnthropicDigestChain(round2) + + // 第 3 轮: system + user + assistant + user 
+ assistant + user + round3 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well"}, + map[string]any{"role": "user", "content": "great"}, + }, + } + chain3 := BuildAnthropicDigestChain(round3) + + t.Logf("Chain1: %s", chain1) + t.Logf("Chain2: %s", chain2) + t.Logf("Chain3: %s", chain3) + + // chain1 是 chain2 的前缀 + if !strings.HasPrefix(chain2, chain1) { + t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2) + } + + // chain2 是 chain3 的前缀 + if !strings.HasPrefix(chain3, chain2) { + t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3) + } + + // chain1 也是 chain3 的前缀(传递性) + if !strings.HasPrefix(chain3, chain1) { + t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3) + } +} + +func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + System: "System A", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "System B", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different system prompts should produce different chains") + } + + // 但 user 部分的 hash 应该相同 + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + if parts1[1] != parts2[1] { + t.Error("Same user message should produce same hash regardless of system") + } +} + +func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + 
map[string]any{"role": "assistant", "content": "ORIGINAL reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "TAMPERED reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different content should produce different chains") + } + + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + // 第一个 user message hash 应该相同 + if parts1[0] != parts2[0] { + t.Error("First user message hash should be the same") + } + // assistant reply hash 应该不同 + if parts1[1] == parts2[1] { + t.Error("Assistant reply hash should differ") + } +} + +func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) { + parsed := &ParsedRequest{ + System: "test system", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed) + chain2 := BuildAnthropicDigestChain(parsed) + + if chain1 != chain2 { + t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2) + } +} + +func TestBuildAnthropicTrieKey(t *testing.T) { + tests := []struct { + name string + groupID int64 + prefixHash string + want string + }{ + { + name: "normal", + groupID: 123, + prefixHash: "abcdef12", + want: "anthropic:trie:123:abcdef12", + }, + { + name: "zero group", + groupID: 0, + prefixHash: "xyz", + want: "anthropic:trie:0:xyz", + }, + { + name: "empty prefix", + groupID: 1, + prefixHash: "", + want: "anthropic:trie:1:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash) + if got != tt.want { + t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, 
got, tt.want) + } + }) + } +} + +func TestGenerateAnthropicDigestSessionKey(t *testing.T) { + tests := []struct { + name string + prefixHash string + uuid string + want string + }{ + { + name: "normal 16 char hash with uuid", + prefixHash: "abcdefgh12345678", + uuid: "550e8400-e29b-41d4-a716-446655440000", + want: "anthropic:digest:abcdefgh:550e8400", + }, + { + name: "exactly 8 chars", + prefixHash: "12345678", + uuid: "abcdefgh", + want: "anthropic:digest:12345678:abcdefgh", + }, + { + name: "short values", + prefixHash: "abc", + uuid: "xyz", + want: "anthropic:digest:abc:xyz", + }, + { + name: "empty values", + prefixHash: "", + uuid: "", + want: "anthropic:digest::", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid) + if got != tt.want { + t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want) + } + }) + } + + // 验证不同 uuid 产生不同 sessionKey + t.Run("different uuid different key", func(t *testing.T) { + hash := "sameprefix123456" + result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a") + result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b") + if result1 == result2 { + t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2) + } + }) +} + +func TestAnthropicSessionTTL(t *testing.T) { + ttl := AnthropicSessionTTL() + if ttl.Seconds() != 300 { + t.Errorf("expected 300 seconds, got: %v", ttl.Seconds()) + } +} + +func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) { + // 测试 content 为 content blocks 数组的情况 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe this image"}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64"}}, + }, + }, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + 
if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index b3e60c21..d9c852e0 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -232,6 +232,14 @@ func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, gro return nil } +func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6ba15399..438d6643 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -313,6 +313,14 @@ type GatewayCache interface { // SaveGeminiSession 保存 Gemini 会话 // Save Gemini session binding SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error + + // FindAnthropicSession 查找 Anthropic 会话(Trie 匹配) + // Find Anthropic session using Trie matching + FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) + + // SaveAnthropicSession 保存 Anthropic 会话 + // Save Anthropic session binding + SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, 
returning 0 if nil @@ -482,23 +490,25 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return s.hashContent(cacheableContent) } - // 3. Fallback: 使用 system 内容 + // 3. 最后 fallback: 使用 system + 所有消息的完整摘要串 + var combined strings.Builder if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { - return s.hashContent(systemText) + combined.WriteString(systemText) } } - - // 4. 最后 fallback: 使用第一条消息 - if len(parsed.Messages) > 0 { - if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { - msgText := s.extractTextFromContent(firstMsg["content"]) + for _, msg := range parsed.Messages { + if m, ok := msg.(map[string]any); ok { + msgText := s.extractTextFromContent(m["content"]) if msgText != "" { - return s.hashContent(msgText) + combined.WriteString(msgText) } } } + if combined.Len() > 0 { + return s.hashContent(combined.String()) + } return "" } @@ -541,6 +551,22 @@ func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, p return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) } +// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) +func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + if digestChain == "" || s.cache == nil { + return "", 0, false + } + return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain) +} + +// SaveAnthropicSession 保存 Anthropic 会话 +func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + if digestChain == "" || s.cache == nil { + return nil + } + return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -1104,7 +1130,6 @@ func (s 
*GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } @@ -1258,7 +1283,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -2163,9 +2187,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2266,9 +2287,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, 
requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } return account, nil } } @@ -2377,9 +2395,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2482,9 +2497,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } return account, nil } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 601e7e2c..0c54dc39 100644 --- 
a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -281,6 +281,14 @@ func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, group return nil } +func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1c2c81ca..159b0afb 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -220,6 +220,14 @@ func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, return nil } +func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { + return "", 0, false +} + +func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { + return nil +} + func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) From 86b503f87fe22b7b0a21ecb60860d5f8bc373848 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 17:57:30 +0800 Subject: [PATCH 03/41] refactor: remove Anthropic digest chain from Messages handler The digest chain fallback is only needed for Gemini endpoints, not 
for the Anthropic Messages API path. Remove the handler integration while keeping the reusable service/repository layer for future use. --- backend/internal/handler/gateway_handler.go | 62 --------------------- 1 file changed, 62 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a505f578..7e6b2f03 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -22,7 +22,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" - "github.com/google/uuid" ) // GatewayHandler handles API gateway requests @@ -213,53 +212,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) } - - // === Anthropic 内容摘要会话 Fallback 逻辑 === - // 当原有会话标识无效时(sessionBoundAccountID == 0),尝试基于内容摘要链匹配 - var anthropicDigestChain string - var anthropicPrefixHash string - var anthropicSessionUUID string - useAnthropicDigestFallback := sessionBoundAccountID == 0 && platform != service.PlatformGemini - - if useAnthropicDigestFallback { - anthropicDigestChain = service.BuildAnthropicDigestChain(parsedReq) - if anthropicDigestChain != "" { - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - anthropicPrefixHash = service.GenerateGeminiPrefixHash( - subject.UserID, - apiKey.ID, - clientIP, - userAgent, - platform, - reqModel, - ) - - foundUUID, foundAccountID, found := h.gatewayService.FindAnthropicSession( - c.Request.Context(), - derefGroupID(apiKey.GroupID), - anthropicPrefixHash, - anthropicDigestChain, - ) - if found { - sessionBoundAccountID = foundAccountID - anthropicSessionUUID = foundUUID - log.Printf("[Anthropic] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", - foundUUID[:8], foundAccountID, truncateDigestChain(anthropicDigestChain)) - - if sessionKey == "" { - sessionKey = 
service.GenerateAnthropicDigestSessionKey(anthropicPrefixHash, foundUUID) - } - _ = h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, foundAccountID) - } else { - anthropicSessionUUID = uuid.New().String() - if sessionKey == "" { - sessionKey = service.GenerateAnthropicDigestSessionKey(anthropicPrefixHash, anthropicSessionUUID) - } - } - } - } - // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 @@ -588,20 +540,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 保存 Anthropic 内容摘要会话(用于 Fallback 匹配) - if useAnthropicDigestFallback && anthropicDigestChain != "" && anthropicPrefixHash != "" { - if err := h.gatewayService.SaveAnthropicSession( - c.Request.Context(), - derefGroupID(apiKey.GroupID), - anthropicPrefixHash, - anthropicDigestChain, - anthropicSessionUUID, - account.ID, - ); err != nil { - log.Printf("[Anthropic] Failed to save digest session: %v", err) - } - } - // 异步记录使用量(subscription已在函数开头获取) go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) From 36e6fb5fc8ff3a4a51f31e4363a5c735b5f13254 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 18:13:37 +0800 Subject: [PATCH 04/41] ci: trigger CI for new PR From e3748da860de64896ba05b6ab1566c3ff2124de1 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 18:18:15 +0800 Subject: [PATCH 05/41] fix(lint): handle errcheck for strings.Builder.WriteString --- backend/internal/service/gateway_service.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 438d6643..480f5b67 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -495,14 +495,14 @@ func (s 
*GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { - combined.WriteString(systemText) + _, _ = combined.WriteString(systemText) } } for _, msg := range parsed.Messages { if m, ok := msg.(map[string]any); ok { msgText := s.extractTextFromContent(m["content"]) if msgText != "" { - combined.WriteString(msgText) + _, _ = combined.WriteString(msgText) } } } From 3077fd279d6edd3e23fa20078f05fba0abfcd27b Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 7 Feb 2026 19:16:59 +0800 Subject: [PATCH 06/41] feat: smart retry max 1 attempt + clear sticky session on failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change antigravitySmartRetryMaxAttempts from 3 to 1 to prevent repeated rate limiting and long waits - Clear sticky session binding (DeleteSessionAccountID) after smart retry exhaustion, so subsequent requests don't hit the same rate-limited account - Add flow diagrams to Forward/ForwardGemini doc comments - Add comprehensive unit tests covering: - Sticky session cleared on retry failure (429, 503, network error) - Sticky session NOT cleared on retry success - Sticky session NOT cleared for non-sticky requests (empty hash) - Sticky session NOT cleared on long delay path (handled by handler) - Nil cache safety (no panic) - MaxAttempts constant verification - End-to-end retryLoop → switchError propagation with session clear --- .../service/antigravity_gateway_service.go | 27 +- .../service/antigravity_smart_retry_test.go | 681 +++++++++++++++++- 2 files changed, 678 insertions(+), 30 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 22065e61..126c2326 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -35,7 +35,7 @@ const ( // - 
预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号 antigravityRateLimitThreshold = 7 * time.Second antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间 - antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数 + antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) // Google RPC 状态和类型常量 @@ -247,6 +247,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } } + // 清除粘性会话绑定,避免下次请求仍命中限流账号 + if s.cache != nil && p.sessionHash != "" { + _ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash) + } + // 返回账号切换信号,让上层切换账号重试 return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -952,6 +957,16 @@ func isModelNotFoundError(statusCode int, body []byte) bool { } // Forward 转发 Claude 协议请求(Claude → Gemini 转换) +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() sessionID := getSessionID(c) @@ -1571,6 +1586,16 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque } // ForwardGemini 转发 Gemini 协议请求 +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? 
→ 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() sessionID := getSessionID(c) diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index 623dfec5..999b408f 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -13,6 +13,23 @@ import ( "github.com/stretchr/testify/require" ) +// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock +// 仅关注 DeleteSessionAccountID 的调用记录 +type stubSmartRetryCache struct { + GatewayCache // 嵌入接口,未实现的方法 panic(确保只调用预期方法) + deleteCalls []deleteSessionCall +} + +type deleteSessionCall struct { + groupID int64 + sessionHash string +} + +func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error { + c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash}) + return nil +} + // mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream type mockSmartRetryUpstream struct { responses []*http.Response @@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { // TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) { - // 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次) + // 智能重试后仍然返回 429(需要提供 1 个响应,因为智能重试最多 1 次) failRespBody := `{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -213,19 +230,9 @@ func 
TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Header: http.Header{}, Body: io.NopCloser(strings.NewReader(failRespBody)), } - failResp2 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } - failResp3 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{failResp1, failResp2, failResp3}, - errors: []error{nil, nil, nil}, + responses: []*http.Response{failResp1}, + errors: []error{nil}, } repo := &stubAntigravityAccountRepo{} @@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Platform: PlatformAntigravity, } - // 3s < 7s 阈值,应该触发智能重试(最多 3 次) + // 3s < 7s 阈值,应该触发智能重试(最多 1 次) respBody := []byte(`{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -284,7 +291,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test // 验证模型限流已设置 require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey) - require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)") + require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") } // TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError @@ -556,19 +563,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing require.True(t, switchErr.IsStickySession) } -// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试 -func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { - // 第一次网络错误,第二次成功 - successResp := &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), - } +// 
TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号 +func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) { + // 唯一一次重试遇到网络错误(nil response) upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误) - errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发 + responses: []*http.Response{nil}, // 返回 nil(模拟网络错误) + errors: []error{nil}, // mock 不返回 error,靠 nil response 触发 } + repo := &stubAntigravityAccountRepo{} account := &Account{ ID: 8, Name: "acc-8", @@ -600,6 +603,7 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, + accountRepo: repo, handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, @@ -612,10 +616,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) - require.NotNil(t, result.resp, "should return successful response after network error recovery") - require.Equal(t, http.StatusOK, result.resp.StatusCode) - require.Nil(t, result.switchError, "should not return switchError on success") - require.Len(t, upstream.calls, 2, "should have made two retry calls") + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel) + require.Len(t, upstream.calls, 1, "should have made one retry call") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, 
"claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } // TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流 @@ -674,3 +683,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } + +// --------------------------------------------------------------------------- +// 以下测试覆盖本次改动: +// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次) +// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID) +// --------------------------------------------------------------------------- + +// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1 +func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) { + require.Equal(t, 1, antigravitySmartRetryMaxAttempts, + "antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession +// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 10, + Name: "acc-10", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": 
"type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-abc", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + // 验证返回 switchError + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + + // 核心断言:DeleteSessionAccountID 被调用,且参数正确 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once") + require.Equal(t, int64(42), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash) + + // 验证仅重试 1 次 + require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} + +// 
TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession +// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空) +func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 11, + Name: "acc-11", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: false, + groupID: 42, + sessionHash: "", // 非粘性会话,sessionHash 为空 + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession 
bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.False(t, result.switchError.IsStickySession) + + // 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic +// 边界:cache 为 nil 时不应 panic +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 12, + Name: "acc-12", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: 
io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-nil-cache", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + // cache 为 nil,不应 panic + svc := &AntigravityGatewayService{cache: nil} + require.NotPanics(t, func() { + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + }) +} + +// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession +// 重试成功时不应清除粘性会话(只有失败才清除) +func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 13, + Name: "acc-13", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": 
"0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-success", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not return switchError on success") + + // 核心断言:重试成功时不应清除粘性会话 + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry") +} + +// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry +// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID +// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理) +func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 14, + Name: "acc-14", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 15s >= 7s 阈值 → 走长延迟路径 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": 
"type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-long-delay", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID + // (由上游 handler 的 shouldClearStickySession 处理) + require.Len(t, cache.deleteCalls, 0, + "long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)") +} + +// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession +// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) { + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil}, // 网络错误 + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 15, + 
Name: "acc-15", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 99, + sessionHash: "sticky-net-error", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 核心断言:网络错误耗尽重试后也应清除粘性绑定 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry") + require.Equal(t, int64(99), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash) +} + +// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession +// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "code": 503, + 
"status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 16, + Name: "acc-16", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 77, + sessionHash: "sticky-503-short", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + 
require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1) + require.Equal(t, int64(77), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey) +} + +// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates +// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播 +// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除 +func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) { + // 初始 429 响应 + initialRespBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + initialResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(initialRespBody)), + } + + // 智能重试也返回 429 + retryRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + retryResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(retryRespBody)), + } + + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{initialResp, retryResp}, + errors: []error{nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 17, + Name: "acc-17", + Type: AccountTypeOAuth, + 
Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{cache: cache} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 55, + sessionHash: "sticky-loop-test", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result when switchError") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop") + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry") + require.Equal(t, int64(55), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash) +} \ No newline at end of file From 77b66653ed96ede34fcb99f5d3bfbf8a04864292 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 01:21:02 +0800 Subject: [PATCH 07/41] fix(gateway): restore upstream account forwarding with dedicated methods v0.1.74 merged upstream accounts into the OAuth path, causing requests to hit the wrong protocol and endpoint. 
Add three upstream-specific methods (testUpstreamConnection, ForwardUpstream, ForwardUpstreamGemini) that use base_url + apiKey auth and passthrough the original body, while reusing the existing response handling and error/retry logic. --- .../service/antigravity_gateway_service.go | 601 ++++++++++++++++++ 1 file changed, 601 insertions(+) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3d3c9cca..fd53ba71 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -650,6 +650,10 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + if account.Type == AccountTypeUpstream { + return s.testUpstreamConnection(ctx, account, modelID) + } + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -966,6 +970,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { // Forward 转发 Claude 协议请求(Claude → Gemini 转换) func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() + + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body, isStickySession) + } + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1585,6 +1594,11 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := 
time.Now() + + if account.Type == AccountTypeUpstream { + return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) + } + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3332,3 +3346,590 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// --------------------------------------------------------------------------- +// Upstream 专用转发方法 +// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 +// --------------------------------------------------------------------------- + +// testUpstreamConnection 测试 upstream 账号连接 +func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, errors.New("upstream account missing base_url in credentials") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, errors.New("upstream account missing api_key in credentials") + } + + mappedModel := s.getMappedModel(account, modelID) + if mappedModel == "" { + return nil, fmt.Errorf("model %s not in whitelist", modelID) + } + + // 构建最小 Claude 格式请求 + requestBody, _ := json.Marshal(map[string]any{ + "model": mappedModel, + "max_tokens": 1, + "messages": []map[string]any{ + {"role": "user", "content": "."}, + }, + "stream": false, + }) + + apiURL := baseURL + "/antigravity/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("构建请求失败: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + proxyURL := "" + if 
account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 从 Claude 格式非流式响应中提取文本 + var claudeResp struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } + text := "" + if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { + text = claudeResp.Content[0].Text + } + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil +} + +// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") + } + + // 解析请求以获取模型和流式标志 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") + } + if strings.TrimSpace(claudeReq.Model) == 
"" { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") + } + + originalModel := claudeReq.Model + mappedModel := s.getMappedModel(account, claudeReq.Model) + if mappedModel == "" { + return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) + } + loadModel := mappedModel + thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 统计模型调用次数 + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + } + + apiURL := baseURL + "/antigravity/v1/messages" + log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) + + // 预检查:模型级限流 + if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { + if remaining < antigravityRateLimitThreshold { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(remaining): + } + } else { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: isStickySession, + } + } + } + + // 重试循环 + var resp *http.Response + var lastErr error + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + + // 透传 anthropic headers + if v := 
c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } else { + req.Header.Set("anthropic-version", "2023-06-01") + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = err + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + // 429/503 重试 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ForceCacheBilling: isStickySession, + } + } + + break // 成功或非限流错误,跳出重试 + } + if resp == nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + } + defer func() { _ = resp.Body.Close() }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // signature 重试 + if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { + log.Printf("%s 
upstream signature error, retrying with thinking stripped", prefix) + retryClaudeReq := claudeReq + retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) + if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { + retryBody, _ := json.Marshal(&retryClaudeReq) + retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) + if err == nil { + retryReq.Header.Set("Content-Type", "application/json") + retryReq.Header.Set("Authorization", "Bearer "+apiKey) + retryReq.Header.Set("x-api-key", apiKey) + retryReq.Header.Set("anthropic-version", "2023-06-01") + if v := c.GetHeader("anthropic-beta"); v != "" { + retryReq.Header.Set("anthropic-beta", v) + } + retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { + resp = retryResp + goto upstreamClaudeSuccess + } + if retryResp != nil { + _ = retryResp.Body.Close() + } + } + } + } + + // prompt too long + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) + } + +upstreamClaudeSuccess: + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if claudeReq.Stream { + streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, 
originalModel) + if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) + if err != nil { + log.Printf("%s status=stream_collect_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) +func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") + } + + if strings.TrimSpace(originalModel) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") + } + if strings.TrimSpace(action) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") + } + if len(body) == 0 { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") + } + quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + + imageSize := s.extractImageSize(body) + + switch action { + case "generateContent", "streamGenerateContent": + // 
ok + case "countTokens": + c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + default: + return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) + } + + mappedModel := s.getMappedModel(account, originalModel) + if mappedModel == "" { + return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) + } + + // Proxy URL (empty when the account has no proxy attached) + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // Count model calls per account + if s.cache != nil { + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) + } + + // Build upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION + upstreamAction := action + if action == "generateContent" && !stream { + // Non-streaming requests also go out as streamGenerateContent; the SSE response is collected by handleGeminiStreamToNonStreaming below + upstreamAction = "streamGenerateContent" + } + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) + if stream || upstreamAction == "streamGenerateContent" { + apiURL += "?alt=sse" + } + + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + + // Pre-check: model-level rate limit (wait if below the threshold, otherwise fail over) + if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { + if remaining < antigravityRateLimitThreshold { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(remaining): + } + } else { + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusServiceUnavailable, + ForceCacheBilling: isStickySession, + } + } + } + + // Retry loop + var resp *http.Response + var lastErr error + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, 
bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = err + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") + } + + // 429/503 重试 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + return nil, ctx.Err() + } + continue + } + + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ForceCacheBilling: isStickySession, + } + } + + break + } + if resp == nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + contentType := resp.Header.Get("Content-Type") + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) 
+ + // 模型兜底 + if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && + isModelNotFoundError(resp.StatusCode, respBody) { + fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) + if fallbackModel != "" && fallbackModel != mappedModel { + log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) + fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) + if stream || upstreamAction == "streamGenerateContent" { + fallbackURL += "?alt=sse" + } + fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) + if err == nil { + fallbackReq.Header.Set("Content-Type", "application/json") + fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) + fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) + if err == nil && fallbackResp.StatusCode < 400 { + _ = resp.Body.Close() + resp = fallbackResp + } else if fallbackResp != nil { + _ = fallbackResp.Body.Close() + } + } + } + } + + // fallback 成功 + if resp.StatusCode < 400 { + goto upstreamGeminiSuccess + } + + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := s.getUpstreamErrorDetail(respBody) + + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + 
Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + if contentType == "" { + contentType = "application/json" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) + } + +upstreamGeminiSuccess: + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + + if stream { + streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) + if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) + if err != nil { + log.Printf("%s status=stream_collect_error error=%v", prefix, err) + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + imageCount := 0 + if isImageGenerationModel(mappedModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} From df3346387fcd0c758362008de867837bd28811b8 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 01:46:50 +0800 Subject: [PATCH 08/41] 
fix(frontend): upstream account edit fields and mixed_scheduling on create - EditAccountModal: add Base URL / API Key fields for upstream type - EditAccountModal: initialize editBaseUrl from credentials on upstream account open - EditAccountModal: save upstream credentials (base_url, api_key) on submit - CreateAccountModal: pass mixed_scheduling extra when creating upstream account --- .../components/account/CreateAccountModal.vue | 3 +- .../components/account/EditAccountModal.vue | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index ba1daea9..7d759be1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2714,7 +2714,8 @@ const handleSubmit = async () => { submitting.value = true try { - await createAccountAndFinish(form.platform, 'upstream', credentials) + const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined + await createAccountAndFinish(form.platform, 'upstream', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 2e428460..986bd297 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -364,6 +364,30 @@ + +
+
+ + +

{{ t('admin.accounts.upstream.baseUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.leaveEmptyToKeep') }}

+
+
+
@@ -1244,6 +1268,9 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'upstream' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editBaseUrl.value = (credentials.base_url as string) || '' } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -1584,6 +1611,22 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'upstream') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.base_url = editBaseUrl.value.trim() + + if (editApiKey.value.trim()) { + newCredentials.api_key = editApiKey.value.trim() + } + + if (!applyTempUnschedConfig(newCredentials)) { + submitting.value = false + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed From 1563bd3dda85e7f18058357fc8fcfdc4308c94ef Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:33:09 +0800 Subject: [PATCH 09/41] feat(upstream): passthrough all client headers instead of manual header setting Replace manual header setting (Content-Type, anthropic-version, anthropic-beta) with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini. Only authentication headers (Authorization, x-api-key) are overridden with upstream account credentials. Hop-by-hop headers are excluded. Add unit tests covering header passthrough, auth override, and hop-by-hop filtering. 
--- .../service/antigravity_gateway_service.go | 312 ++++-------------- .../upstream_header_passthrough_test.go | 285 ++++++++++++++++ 2 files changed, 352 insertions(+), 245 deletions(-) create mode 100644 backend/internal/service/upstream_header_passthrough_test.go diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fd53ba71..fc29eeb3 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,6 +47,21 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) +// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 +var upstreamHopByHopHeaders = map[string]bool{ + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "proxy-connection": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, + "host": true, + "content-length": true, +} + // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ @@ -3456,10 +3471,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 代理 URL proxyURL := "" @@ -3469,98 +3480,38 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. 
// 统计模型调用次数 if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) } apiURL := baseURL + "/antigravity/v1/messages" log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - // 透传 anthropic headers - if v := c.GetHeader("anthropic-version"); v != "" { - req.Header.Set("anthropic-version", v) - } else { - req.Header.Set("anthropic-version", "2023-06-01") - } - if v := 
c.GetHeader("anthropic-beta"); v != "" { - req.Header.Set("anthropic-beta", v) - } - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break // 成功或非限流错误,跳出重试 + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) } defer func() { _ = resp.Body.Close() }() @@ -3568,44 +3519,7 @@ func (s *AntigravityGatewayService) 
ForwardUpstream(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - // signature 重试 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { - log.Printf("%s upstream signature error, retrying with thinking stripped", prefix) - retryClaudeReq := claudeReq - retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) - if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { - retryBody, _ := json.Marshal(&retryClaudeReq) - retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) - if err == nil { - retryReq.Header.Set("Content-Type", "application/json") - retryReq.Header.Set("Authorization", "Bearer "+apiKey) - retryReq.Header.Set("x-api-key", apiKey) - retryReq.Header.Set("anthropic-version", "2023-06-01") - if v := c.GetHeader("anthropic-beta"); v != "" { - retryReq.Header.Set("anthropic-beta", v) - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) - if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { - resp = retryResp - goto upstreamClaudeSuccess - } - if retryResp != nil { - _ = retryResp.Body.Close() - } - } - } - } - - // prompt too long - if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { - return nil, &PromptTooLongError{ - StatusCode: resp.StatusCode, - RequestID: resp.Header.Get("x-request-id"), - Body: respBody, - } - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3614,7 +3528,7 @@ func (s 
*AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } -upstreamClaudeSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) @@ -3674,7 +3588,6 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) imageSize := s.extractImageSize(body) @@ -3712,143 +3625,52 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - upstreamAction := action - if action == "generateContent" && !stream { - // 非流式也用 streamGenerateContent,与 OAuth 路径行为一致 - upstreamAction = action - } - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) + if stream || action == "streamGenerateContent" { apiURL += "?alt=sse" } - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, 
bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if 
!sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() + defer func() { _ = resp.Body.Close() }() // 错误响应处理 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) contentType := resp.Header.Get("Content-Type") - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - // 模型兜底 - if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && - isModelNotFoundError(resp.StatusCode, respBody) { - fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) - if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { - fallbackURL += "?alt=sse" - } - fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) - if err == nil { - fallbackReq.Header.Set("Content-Type", "application/json") - fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) - fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) - if err 
== nil && fallbackResp.StatusCode < 400 { - _ = resp.Body.Close() - resp = fallbackResp - } else if fallbackResp != nil { - _ = fallbackResp.Body.Close() - } - } - } - } - - // fallback 成功 - if resp.StatusCode < 400 { - goto upstreamGeminiSuccess - } requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(respBody) @@ -3886,7 +3708,7 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } -upstreamGeminiSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go new file mode 100644 index 00000000..51d8588b --- /dev/null +++ b/backend/internal/service/upstream_header_passthrough_test.go @@ -0,0 +1,285 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// httpUpstreamCapture captures the outgoing *http.Request for assertion. 
+type httpUpstreamCapture struct { + capturedReq *http.Request + resp *http.Response + err error +} + +func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func newUpstreamAccount() *Account { + return &Account{ + ID: 100, + Name: "upstream-test", + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "base_url": "https://upstream.example.com", + "api_key": "sk-upstream-secret", + }, + } +} + +// makeSSEOKResponse builds a minimal SSE response that +// handleClaudeStreamingResponse / handleGeminiStreamingResponse +// can consume without error. +// We return 502 to bypass streaming and hit the error branch instead, +// which is sufficient for testing header passthrough. 
+func makeUpstreamErrorResponse() *http.Response { + body := []byte(`{"error":{"message":"test error"}}`) + return &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + } +} + +// --- ForwardUpstream tests --- + +func TestForwardUpstream_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", "2024-10-22") + req.Header.Set("anthropic-beta", "output-128k-2025-02-19") + req.Header.Set("X-Custom-Header", "custom-value") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) + require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) + require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) +} + +func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, 
+ "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // 客户端发来的认证头应被覆盖 + req.Header.Set("Authorization", "Bearer client-token") + req.Header.Set("x-api-key", "client-api-key") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key,而非客户端的 + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) + require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) +} + +func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Keep-Alive", "timeout=5") + req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Te", "trailers") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + 
require.Empty(t, captured.Header.Get("Keep-Alive")) + require.Empty(t, captured.Header.Get("Transfer-Encoding")) + require.Empty(t, captured.Header.Get("Upgrade")) + require.Empty(t, captured.Header.Get("Te")) + + // 但普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} + +// --- ForwardUpstreamGemini tests --- + +func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Custom-Gemini", "gemini-value") + req.Header.Set("X-Request-Id", "req-abc-123") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) + require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) +} + +func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, 
"/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer client-gemini-token") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) +} + +func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + req.Header.Set("Host", "evil.example.com") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Proxy-Authorization")) + // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 + require.Empty(t, 
captured.Header.Values("Host")) + + // 普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} From 4f57d7f76188f2c767060c37d516ceb3fb05cdfe Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:36:35 +0800 Subject: [PATCH 10/41] fix: add nil guard for gin.Context in header passthrough to satisfy staticcheck SA5011 --- .../service/antigravity_gateway_service.go | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fc29eeb3..c2983c47 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3492,12 +3492,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } } } // 覆盖认证头 @@ -3638,12 +3640,14 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + 
req.Header.Add(key, v) + } } } // 覆盖认证头 From 6ab77f5eb5afceb99eb32bba011261866bf6cf14 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:49:43 +0800 Subject: [PATCH 11/41] fix(upstream): passthrough response body directly instead of parsing SSE ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response directly to the client (headers + body), not parse it as SSE stream. --- .../service/antigravity_gateway_service.go | 99 +++++++------------ 1 file changed, 38 insertions(+), 61 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index c2983c47..2d96b1ab 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3530,39 +3530,30 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - if claudeReq.Stream { - streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs + c.Status(resp.StatusCode) + _, 
copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Model: originalModel, + Stream: claudeReq.Stream, + Duration: time.Since(startTime), }, nil } @@ -3712,35 +3703,23 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - - if stream { - streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } - - if usage == nil { - usage = &ClaudeUsage{} + c.Status(resp.StatusCode) + _, copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } imageCount := 0 @@ -3749,13 +3728,11 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - 
FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } From fb58560d15fa34d2fc14b89f301e946e039861e7 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:06:25 +0800 Subject: [PATCH 12/41] refactor(upstream): replace upstream account type with apikey, auto-append /antigravity Upstream accounts now use the standard APIKey type instead of a dedicated upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append /antigravity for Antigravity platform APIKey accounts, eliminating the need for separate upstream forwarding methods. - Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection - Remove upstream branch guards in Forward/ForwardGemini/TestConnection - Add migration 052 to convert existing upstream accounts to apikey - Update frontend CreateAccountModal to create apikey type - Add unit tests for GetBaseURL and GetGeminiBaseURL --- backend/internal/handler/gateway_handler.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 2 +- backend/internal/service/account.go | 16 + .../internal/service/account_base_url_test.go | 160 ++++++++ .../service/antigravity_gateway_service.go | 386 ------------------ .../service/gemini_messages_compat_service.go | 25 +- .../upstream_header_passthrough_test.go | 285 ------------- .../052_migrate_upstream_to_apikey.sql | 11 + .../components/account/CreateAccountModal.vue | 6 +- 9 files changed, 197 insertions(+), 696 deletions(-) create mode 100644 backend/internal/service/account_base_url_test.go delete mode 100644 backend/internal/service/upstream_header_passthrough_test.go create mode 100644 backend/migrations/052_migrate_upstream_to_apikey.sql diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..255d3fab 100644 --- 
a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..2b69be2e 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..138d5bcb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 
+func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns 
empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: 
PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2d96b1ab..4ea73e64 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -665,9 +665,6 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - if account.Type == AccountTypeUpstream { - return s.testUpstreamConnection(ctx, account, modelID) - } // 获取 token if s.tokenProvider == nil { @@ -986,10 +983,6 @@ func isModelNotFoundError(statusCode int, body []byte) bool { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstream(ctx, c, account, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1610,10 +1603,6 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { 
startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3361,378 +3350,3 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } - -// --------------------------------------------------------------------------- -// Upstream 专用转发方法 -// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 -// --------------------------------------------------------------------------- - -// testUpstreamConnection 测试 upstream 账号连接 -func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, errors.New("upstream account missing base_url in credentials") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, errors.New("upstream account missing api_key in credentials") - } - - mappedModel := s.getMappedModel(account, modelID) - if mappedModel == "" { - return nil, fmt.Errorf("model %s not in whitelist", modelID) - } - - // 构建最小 Claude 格式请求 - requestBody, _ := json.Marshal(map[string]any{ - "model": mappedModel, - "max_tokens": 1, - "messages": []map[string]any{ - {"role": "user", "content": "."}, - }, - "stream": false, - }) - - apiURL := baseURL + "/antigravity/v1/messages" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("构建请求失败: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - req.Header.Set("anthropic-version", "2023-06-01") - - proxyURL := 
"" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) - } - - // 从 Claude 格式非流式响应中提取文本 - var claudeResp struct { - Content []struct { - Text string `json:"text"` - } `json:"content"` - } - text := "" - if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { - text = claudeResp.Content[0].Text - } - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil -} - -// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") - } - - // 解析请求以获取模型和流式标志 - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") - } - if 
strings.TrimSpace(claudeReq.Model) == "" { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") - } - - originalModel := claudeReq.Model - mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel == "" { - return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - apiURL := baseURL + "/antigravity/v1/messages" - log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - - if 
s.shouldFailoverUpstreamError(resp.StatusCode) { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - - return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - }, nil -} - -// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") - } - - if strings.TrimSpace(originalModel) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") - } - if strings.TrimSpace(action) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") - } - if len(body) == 0 { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is 
empty") - } - - imageSize := s.extractImageSize(body) - - switch action { - case "generateContent", "streamGenerateContent": - // ok - case "countTokens": - c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(time.Now()), - FirstTokenMs: nil, - }, nil - default: - return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) - } - - mappedModel := s.getMappedModel(account, originalModel) - if mappedModel == "" { - return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) - if stream || action == "streamGenerateContent" { - apiURL += "?alt=sse" - } - - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, 
account.Concurrency) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - contentType := resp.Header.Get("Content-Type") - - requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := s.getUpstreamErrorDetail(respBody) - - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - if contentType == "" { - contentType = "application/json" - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) - c.Data(resp.StatusCode, contentType, respBody) - return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for 
key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - imageCount := 0 - if isImageGenerationModel(mappedModel) { - imageCount = 1 - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 0f156c2e..4e0442fd 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go deleted file mode 100644 index 51d8588b..00000000 --- a/backend/internal/service/upstream_header_passthrough_test.go +++ /dev/null @@ -1,285 +0,0 @@ -//go:build unit - -package service - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// httpUpstreamCapture captures the outgoing *http.Request for assertion. 
-type httpUpstreamCapture struct { - capturedReq *http.Request - resp *http.Response - err error -} - -func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func newUpstreamAccount() *Account { - return &Account{ - ID: 100, - Name: "upstream-test", - Platform: PlatformAntigravity, - Type: AccountTypeUpstream, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "base_url": "https://upstream.example.com", - "api_key": "sk-upstream-secret", - }, - } -} - -// makeSSEOKResponse builds a minimal SSE response that -// handleClaudeStreamingResponse / handleGeminiStreamingResponse -// can consume without error. -// We return 502 to bypass streaming and hit the error branch instead, -// which is sufficient for testing header passthrough. 
-func makeUpstreamErrorResponse() *http.Response { - body := []byte(`{"error":{"message":"test error"}}`) - return &http.Response{ - StatusCode: http.StatusBadGateway, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(bytes.NewReader(body)), - } -} - -// --- ForwardUpstream tests --- - -func TestForwardUpstream_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", "2024-10-22") - req.Header.Set("anthropic-beta", "output-128k-2025-02-19") - req.Header.Set("X-Custom-Header", "custom-value") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) - require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) - require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) -} - -func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, 
- "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - // 客户端发来的认证头应被覆盖 - req.Header.Set("Authorization", "Bearer client-token") - req.Header.Set("x-api-key", "client-api-key") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key,而非客户端的 - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) - require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) -} - -func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Keep-Alive", "timeout=5") - req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Te", "trailers") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - 
require.Empty(t, captured.Header.Get("Keep-Alive")) - require.Empty(t, captured.Header.Get("Transfer-Encoding")) - require.Empty(t, captured.Header.Get("Upgrade")) - require.Empty(t, captured.Header.Get("Te")) - - // 但普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} - -// --- ForwardUpstreamGemini tests --- - -func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Custom-Gemini", "gemini-value") - req.Header.Set("X-Request-Id", "req-abc-123") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) - require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) -} - -func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, 
"/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer client-gemini-token") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) -} - -func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") - req.Header.Set("Host", "evil.example.com") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Proxy-Authorization")) - // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 - require.Empty(t, 
captured.Header.Values("Host")) - - // 普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} diff --git a/backend/migrations/052_migrate_upstream_to_apikey.sql b/backend/migrations/052_migrate_upstream_to_apikey.sql new file mode 100644 index 00000000..974f3f3c --- /dev/null +++ b/backend/migrations/052_migrate_upstream_to_apikey.sql @@ -0,0 +1,11 @@ +-- Migrate upstream accounts to apikey type +-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts +-- with base_url pointing to an upstream sub2api instance can reuse the standard +-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends +-- /antigravity for Antigravity platform APIKey accounts. + +UPDATE accounts +SET type = 'apikey' +WHERE type = 'upstream' + AND platform = 'antigravity' + AND deleted_at IS NULL; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 7d759be1..603941c1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2289,9 +2289,9 @@ watch( watch( [accountCategory, addMethod, antigravityAccountType], ([category, method, agType]) => { - // Antigravity upstream 类型 + // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { - form.type = 'upstream' + form.type = 'apikey' return } if (category === 'oauth-based') { @@ -2715,7 +2715,7 @@ const handleSubmit = async () => { submitting.value = true try { const extra = mixedScheduling.value ? 
{ mixed_scheduling: true } : undefined - await createAccountAndFinish(form.platform, 'upstream', credentials, extra) + await createAccountAndFinish(form.platform, 'apikey', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally { From 3c936441469d9483bd02c2681fcfbea9fa271f9a Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:14:58 +0800 Subject: [PATCH 13/41] chore: bump version to 0.1.74.7 --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0768f09..bc88be6e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.70 +0.1.74.7 From b4ec65785d9fbf525d9cc0202663c830be1a6791 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 8 Feb 2026 13:26:28 +0800 Subject: [PATCH 14/41] =?UTF-8?q?fix:=20apikey=E7=B1=BB=E5=9E=8B=E8=B4=A6?= =?UTF-8?q?=E5=8F=B7test=E5=8E=BB=E6=8E=89oauth-2025-04-20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/account_test_service.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 3290fe52..899a4498 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -245,7 +245,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set common headers req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) // Apply Claude Code client headers for key, value := range claude.DefaultHeaders { @@ -254,8 +253,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set authentication header if useBearer { + 
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) req.Header.Set("Authorization", "Bearer "+authToken) } else { + req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) req.Header.Set("x-api-key", authToken) } From 69816f8691e9374adfafde596c5b5a34ec96ddaf Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:30:39 +0800 Subject: [PATCH 15/41] fix: remove unused upstreamHopByHopHeaders variable to pass golangci-lint --- .../service/antigravity_gateway_service.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 4ea73e64..26b1c530 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,21 +47,6 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) -// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 -var upstreamHopByHopHeaders = map[string]bool{ - "connection": true, - "keep-alive": true, - "proxy-authenticate": true, - "proxy-authorization": true, - "proxy-connection": true, - "te": true, - "trailer": true, - "transfer-encoding": true, - "upgrade": true, - "host": true, - "content-length": true, -} - // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ From b1c30df8e300fb4258352e5384ee8fe3fe3cc240 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 8 Feb 2026 14:00:02 +0800 Subject: [PATCH 16/41] fix(ui): unify admin table toolbar layout with search and buttons in single row Standardize filter bar layout across admin pages to place search/filters on left and action buttons on right within the same row, improving visual consistency and space utilization. 
--- .../src/views/admin/AnnouncementsView.vue | 50 ++++++------- frontend/src/views/admin/PromoCodesView.vue | 50 ++++++------- frontend/src/views/admin/ProxiesView.vue | 74 +++++++++---------- frontend/src/views/admin/RedeemView.vue | 56 +++++++------- frontend/src/views/admin/UsersView.vue | 6 +- 5 files changed, 114 insertions(+), 122 deletions(-) diff --git a/frontend/src/views/admin/AnnouncementsView.vue b/frontend/src/views/admin/AnnouncementsView.vue index 38574454..08d7b871 100644 --- a/frontend/src/views/admin/AnnouncementsView.vue +++ b/frontend/src/views/admin/AnnouncementsView.vue @@ -1,26 +1,10 @@