diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 614ded8d..bbc9c181 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1,5 +1,3 @@ -// Package handler provides HTTP request handlers for the API gateway. -// It handles authentication, request routing, concurrency control, and billing validation. package handler import ( @@ -13,6 +11,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -29,7 +28,6 @@ type GatewayHandler struct { userService *service.UserService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper - opsService *service.OpsService } // NewGatewayHandler creates a new GatewayHandler @@ -40,7 +38,6 @@ func NewGatewayHandler( userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, - opsService *service.OpsService, ) *GatewayHandler { return &GatewayHandler{ gatewayService: gatewayService, @@ -49,15 +46,14 @@ func NewGatewayHandler( userService: userService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), - opsService: opsService, } } // Messages handles Claude API compatible messages endpoint // POST /v1/messages func (h *GatewayHandler) Messages(c *gin.Context) { - // 从context获取apiKey和user(APIKeyAuth中间件已设置) - apiKey, ok := middleware2.GetAPIKeyFromContext(c) + // 从context获取apiKey和user(ApiKeyAuth中间件已设置) + apiKey, ok := middleware2.GetApiKeyFromContext(c) if !ok { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return @@ -92,7 +88,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } reqModel := parsedReq.Model reqStream := parsedReq.Stream - setOpsRequestContext(c, reqModel, reqStream) // 验证 model 必填 if reqModel == "" { @@ -264,7 +259,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, - APIKey: apiKey, + ApiKey: apiKey, User: apiKey.User, Account: usedAccount, Subscription: subscription, @@ -388,7 +383,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, - APIKey: apiKey, + ApiKey: apiKey, User: apiKey.User, Account: usedAccount, Subscription: subscription, @@ -405,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Returns models based on account configurations (model_mapping whitelist) // Falls back to default models if no whitelist is configured func (h *GatewayHandler) Models(c *gin.Context) { - apiKey, _ := middleware2.GetAPIKeyFromContext(c) + apiKey, _ := middleware2.GetApiKeyFromContext(c) var groupID *int64 var platform string @@ -451,10 +446,19 @@ func (h *GatewayHandler) Models(c *gin.Context) { }) } +// AntigravityModels 返回 Antigravity 支持的全部模型 +// GET /antigravity/models +func (h *GatewayHandler) AntigravityModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": antigravity.DefaultModels(), + }) +} + // Usage handles getting account balance for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { - apiKey, ok := middleware2.GetAPIKeyFromContext(c) + apiKey, ok := middleware2.GetApiKeyFromContext(c) if !ok { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return @@ -579,7 +583,6 @@ func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) // handleStreamingAwareError handles errors that may occur after streaming has started func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { - recordOpsError(c, h.opsService, status, errType, message, "") // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { @@ -611,7 +614,6 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { - recordOpsError(c, h.opsService, status, errType, message, "") c.JSON(status, gin.H{ "type": "error", "error": gin.H{ @@ -625,8 +627,8 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess // POST /v1/messages/count_tokens // 特点:校验订阅/余额,但不计算并发、不记录使用量 func (h *GatewayHandler) CountTokens(c *gin.Context) { - // 从context获取apiKey和user(APIKeyAuth中间件已设置) - apiKey, ok := middleware2.GetAPIKeyFromContext(c) + // 从context获取apiKey和user(ApiKeyAuth中间件已设置) + apiKey, ok := middleware2.GetApiKeyFromContext(c) if !ok { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 79ec9950..71678bed 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -20,7 +21,7 @@ import ( // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { - apiKey, ok := middleware.GetAPIKeyFromContext(c) + apiKey, ok := middleware.GetApiKeyFromContext(c) if !ok || apiKey == nil { googleError(c, http.StatusUnauthorized, "Invalid API key") return @@ -32,9 +33,9 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { return } - // 强制 antigravity 模式:直接返回静态模型列表 + // 强制 antigravity 模式:返回 antigravity 支持的模型列表 if forcePlatform == service.PlatformAntigravity { - c.JSON(http.StatusOK, gemini.FallbackModelsList()) + c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList()) return } @@ -66,7 +67,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { // GeminiV1BetaGetModel proxies: // GET /v1beta/models/{model} func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { - apiKey, ok := middleware.GetAPIKeyFromContext(c) + apiKey, ok := middleware.GetApiKeyFromContext(c) if !ok || apiKey == nil { googleError(c, http.StatusUnauthorized, "Invalid API key") return @@ -84,9 +85,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { return } - // 强制 antigravity 模式:直接返回静态模型信息 + // 强制 antigravity 模式:返回 antigravity 模型信息 if forcePlatform == service.PlatformAntigravity { - c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) + c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName)) return } @@ -119,7 +120,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { // POST /v1beta/models/{model}:generateContent // POST /v1beta/models/{model}:streamGenerateContent?alt=sse func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { - apiKey, ok := middleware.GetAPIKeyFromContext(c) + apiKey, ok := middleware.GetApiKeyFromContext(c) if !ok || apiKey == nil { googleError(c, http.StatusUnauthorized, "Invalid API key") return @@ -298,7 +299,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, - APIKey: apiKey, + ApiKey: apiKey, User: apiKey.User, Account: usedAccount, Subscription: subscription, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 34e6b1f4..8a29cd10 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -138,3 +138,91 @@ type ErrorDetail struct { Type string `json:"type"` Message string `json:"message"` } + +// modelDef Antigravity 模型定义(内部使用) +type modelDef struct { + ID string + DisplayName string + CreatedAt string // 仅 Claude API 格式使用 +} + +// Antigravity 支持的 Claude 模型 +var claudeModels = []modelDef{ + {ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"}, + {ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"}, + {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, +} + +// Antigravity 支持的 Gemini 模型 +var geminiModels = []modelDef{ + {ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"}, +} + +// ========== Claude API 格式 (/v1/models) ========== + +// ClaudeModel Claude API 模型格式 +type ClaudeModel struct { + ID string `json:"id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + CreatedAt string `json:"created_at"` +} + +// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini) +func DefaultModels() []ClaudeModel { + all := append(claudeModels, geminiModels...) + result := make([]ClaudeModel, len(all)) + for i, m := range all { + result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt} + } + return result +} + +// ========== Gemini v1beta 格式 (/v1beta/models) ========== + +// GeminiModel Gemini v1beta 模型格式 +type GeminiModel struct { + Name string `json:"name"` + DisplayName string `json:"displayName,omitempty"` + SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"` +} + +// GeminiModelsListResponse Gemini v1beta 模型列表响应 +type GeminiModelsListResponse struct { + Models []GeminiModel `json:"models"` +} + +var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"} + +// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型) +func DefaultGeminiModels() []GeminiModel { + result := make([]GeminiModel, len(geminiModels)) + for i, m := range geminiModels { + result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods} + } + return result +} + +// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应 +func FallbackGeminiModelsList() GeminiModelsListResponse { + return GeminiModelsListResponse{Models: DefaultGeminiModels()} +} + +// FallbackGeminiModel 返回单个模型信息(v1beta 格式) +func FallbackGeminiModel(model string) GeminiModel { + if model == "" { + return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods} + } + name := model + if len(model) < 7 || model[:7] != "models/" { + name = "models/" + model + } + return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods} +} diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 90ff34e7..003398bd 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -1,5 +1,3 @@ -// Package antigravity provides a client for interacting with Google's Antigravity API, -// handling OAuth authentication, token management, and account tier information retrieval. package antigravity import ( @@ -59,6 +57,29 @@ type TierInfo struct { Description string `json:"description"` // 描述 } +// UnmarshalJSON supports both legacy string tiers and object tiers. +func (t *TierInfo) UnmarshalJSON(data []byte) error { + data = bytes.TrimSpace(data) + if len(data) == 0 || string(data) == "null" { + return nil + } + if data[0] == '"' { + var id string + if err := json.Unmarshal(data, &id); err != nil { + return err + } + t.ID = id + return nil + } + type alias TierInfo + var decoded alias + if err := json.Unmarshal(data, &decoded); err != nil { + return err + } + *t = TierInfo(decoded) + return nil +} + // IneligibleTier 不符合条件的层级信息 type IneligibleTier struct { Tier *TierInfo `json:"tier,omitempty"` diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 8e3e3885..67f6c3e7 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -143,9 +143,10 @@ type GeminiCandidate struct { // GeminiUsageMetadata Gemini 用量元数据 type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount,omitempty"` - CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` - TotalTokenCount int `json:"totalTokenCount,omitempty"` + PromptTokenCount int `json:"promptTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + TotalTokenCount int `json:"totalTokenCount,omitempty"` } // DefaultSafetySettings 默认安全设置(关闭所有过滤) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 3af6579c..9a62ea03 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -150,13 +150,18 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 历史 assistant 消息不能添加没有 signature 的 dummy thinking block if allowDummyThought && role == "model" && isThinkingEnabled && i == len(messages)-1 { hasThoughtPart := false - for _, p := range parts { + firstPartIsThought := false + for idx, p := range parts { if p.Thought { hasThoughtPart = true + if idx == 0 { + firstPartIsThought = true + } break } } - if !hasThoughtPart && len(parts) > 0 { + // 如果没有thinking part,或者有thinking part但不在第一个位置,都需要在开头添加dummy thinking block + if len(parts) > 0 && (!hasThoughtPart || !firstPartIsThought) { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ Text: "Thinking...", @@ -236,6 +241,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, thought // Claude via Vertex: // - signature 是上游返回的完整性令牌;本地不需要/无法验证,只能透传 // - 缺失/无效 signature(例如来自 Gemini 的 dummy signature)会导致上游 400 + // - 为避免泄露 thinking 内容,缺失/无效 signature 的 thinking 直接丢弃 if signature == "" || signature == dummyThoughtSignature { continue } @@ -429,7 +435,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for i, tool := range tools { + for _, tool := range tools { // 跳过无效工具名称 if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") @@ -448,10 +454,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { description = tool.Custom.Description inputSchema = tool.Custom.InputSchema - // 调试日志:记录 custom 工具的 schema - if schemaJSON, err := json.Marshal(inputSchema); err == nil { - log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) - } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -468,11 +470,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } - // 调试日志:记录清理后的 schema - if paramsJSON, err := json.Marshal(params); err == nil { - log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) - } - funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -627,20 +624,16 @@ func cleanSchemaValue(value any) any { if k == "additionalProperties" { if boolVal, ok := val.(bool); ok { result[k] = boolVal - log.Printf("[Debug] additionalProperties is bool: %v", boolVal) } else { // 如果是 schema 对象,转换为 false(更安全的默认值) result[k] = false - log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) } continue } - // 递归清理所有值 result[k] = cleanSchemaValue(val) } return result - case []any: // 递归处理数组中的每个元素 cleaned := make([]any, 0, len(v)) diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 845ae033..171ad078 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { description string }{ { - name: "Claude model - skip thinking block without signature", + name: "Claude model - drop thinking without signature", content: `[ {"type": "text", "text": "Hello"}, {"type": "thinking", "thinking": "Let me think...", "signature": ""}, {"type": "text", "text": "World"} ]`, thoughtMode: thoughtSignatureModePreserve, - expectedParts: 2, // 只有两个text block - description: "Claude模型应该跳过无signature的thinking block", + expectedParts: 2, // thinking 内容被丢弃 + description: "Claude模型应丢弃无signature的thinking block内容", }, { name: "Claude model - preserve thinking block with signature", diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 799de694..9f63c958 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -232,10 +232,18 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon stopReason = "max_tokens" } + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, + // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 usage := ClaudeUsage{} if geminiResp.UsageMetadata != nil { - usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount + cached := geminiResp.UsageMetadata.CachedContentTokenCount + prompt := geminiResp.UsageMetadata.PromptTokenCount + if cached > prompt { + cached = prompt + } + usage.InputTokens = prompt - cached usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + usage.CacheReadInputTokens = cached } // 生成响应 ID diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index c5d954f5..acb33354 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -29,8 +29,9 @@ type StreamingProcessor struct { originalModel string // 累计 usage - inputTokens int - outputTokens int + inputTokens int + outputTokens int + cacheReadTokens int } // NewStreamingProcessor 创建流式响应处理器 @@ -76,9 +77,17 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { } // 更新 usage + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, + // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 if geminiResp.UsageMetadata != nil { - p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount + cached := geminiResp.UsageMetadata.CachedContentTokenCount + prompt := geminiResp.UsageMetadata.PromptTokenCount + if cached > prompt { + cached = prompt + } + p.inputTokens = prompt - cached p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + p.cacheReadTokens = cached } // 处理 parts @@ -108,8 +117,9 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { } usage := &ClaudeUsage{ - InputTokens: p.inputTokens, - OutputTokens: p.outputTokens, + InputTokens: p.inputTokens, + OutputTokens: p.outputTokens, + CacheReadInputTokens: p.cacheReadTokens, } return result.Bytes(), usage @@ -123,8 +133,14 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte usage := ClaudeUsage{} if v1Resp.Response.UsageMetadata != nil { - usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount + cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount + prompt := v1Resp.Response.UsageMetadata.PromptTokenCount + if cached > prompt { + cached = prompt + } + usage.InputTokens = prompt - cached usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + usage.CacheReadInputTokens = cached } responseID := v1Resp.ResponseID @@ -418,8 +434,9 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { } usage := ClaudeUsage{ - InputTokens: p.inputTokens, - OutputTokens: p.outputTokens, + InputTokens: p.inputTokens, + OutputTokens: p.outputTokens, + CacheReadInputTokens: p.cacheReadTokens, } deltaEvent := map[string]any{ diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index d9e0bb81..941f1ce9 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -13,8 +13,8 @@ import ( func RegisterGatewayRoutes( r *gin.Engine, h *handler.Handlers, - apiKeyAuth middleware.APIKeyAuthMiddleware, - apiKeyService *service.APIKeyService, + apiKeyAuth middleware.ApiKeyAuthMiddleware, + apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, ) { @@ -36,7 +36,7 @@ func RegisterGatewayRoutes( // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") gemini.Use(bodyLimit) - gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) @@ -47,6 +47,9 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + // Antigravity 模型列表 + r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) + // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) antigravityV1 := r.Group("/antigravity/v1") antigravityV1.Use(bodyLimit) @@ -55,14 +58,14 @@ func RegisterGatewayRoutes( { antigravityV1.POST("/messages", h.Gateway.Messages) antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens) - antigravityV1.GET("/models", h.Gateway.Models) + antigravityV1.GET("/models", h.Gateway.AntigravityModels) antigravityV1.GET("/usage", h.Gateway.Usage) } antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) - antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index be908189..5f398740 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -49,11 +49,11 @@ var antigravityPrefixMapping = []struct { {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx - {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx + {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-opus-4-5", "claude-opus-4-5-thinking"}, - {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx + {"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet {"claude-sonnet-4", "claude-sonnet-4-5"}, - {"claude-haiku-4", "gemini-3-flash"}, + {"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet {"claude-opus-4", "claude-opus-4-5-thinking"}, {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等 } @@ -64,6 +64,7 @@ type AntigravityGatewayService struct { tokenProvider *AntigravityTokenProvider rateLimitService *RateLimitService httpUpstream HTTPUpstream + settingService *SettingService } func NewAntigravityGatewayService( @@ -72,12 +73,14 @@ func NewAntigravityGatewayService( tokenProvider *AntigravityTokenProvider, rateLimitService *RateLimitService, httpUpstream HTTPUpstream, + settingService *SettingService, ) *AntigravityGatewayService { return &AntigravityGatewayService{ accountRepo: accountRepo, tokenProvider: tokenProvider, rateLimitService: rateLimitService, httpUpstream: httpUpstream, + settingService: settingService, } } @@ -308,6 +311,7 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt } // isSignatureRelatedError 检测是否为 signature 相关的 400 错误 +// 注意:不包含 "thinking" 关键词,避免误判消息格式错误为 signature 错误 func isSignatureRelatedError(statusCode int, body []byte) bool { if statusCode != 400 { return false @@ -318,7 +322,6 @@ func isSignatureRelatedError(statusCode int, body []byte) bool { "signature", "thought_signature", "thoughtsignature", - "thinking", "invalid signature", "signature validation", } @@ -331,28 +334,60 @@ func isSignatureRelatedError(statusCode int, body []byte) bool { return false } -// stripThinkingFromClaudeRequest 从 Claude 请求中移除所有 thinking 相关内容 +// isModelNotFoundError 检测是否为模型不存在的 404 错误 +func isModelNotFoundError(statusCode int, body []byte) bool { + if statusCode != 404 { + return false + } + + bodyStr := strings.ToLower(string(body)) + keywords := []string{ + "model not found", + "model does not exist", + "unknown model", + "invalid model", + } + + for _, keyword := range keywords { + if strings.Contains(bodyStr, keyword) { + return true + } + } + return false +} + +// stripThinkingFromClaudeRequest 从 Claude 请求中移除有问题的 thinking 块 +// 策略:只移除历史消息中带 dummy signature 的 thinking 块,保留本次 thinking 配置 +// 这样可以让本次对话仍然使用 thinking 功能,只是清理历史中可能导致问题的内容 func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) *antigravity.ClaudeRequest { // 创建副本 stripped := *req - // 移除 thinking 配置 - stripped.Thinking = nil + // 保留 thinking 配置,让本次对话仍然可以使用 thinking + // stripped.Thinking = nil // 不再移除 - // 移除消息中的 thinking 块 + // 只移除消息中带 dummy signature 的 thinking 块 if len(stripped.Messages) > 0 { newMessages := make([]antigravity.ClaudeMessage, 0, len(stripped.Messages)) for _, msg := range stripped.Messages { newMsg := msg - // 如果 content 是数组,过滤 thinking 块 + // 如果 content 是数组,过滤有问题的 thinking 块 var blocks []map[string]any if err := json.Unmarshal(msg.Content, &blocks); err == nil { filtered := make([]map[string]any, 0, len(blocks)) for _, block := range blocks { - // 跳过有 type="thinking" 的块 + // 跳过带 dummy signature 的 thinking 块 if blockType, ok := block["type"].(string); ok && blockType == "thinking" { - continue + if sig, ok := block["signature"].(string); ok { + // 移除 dummy signature 的 thinking 块 + if sig == "skip_thought_signature_validator" || sig == "" { + continue + } + } else { + // 没有 signature 字段的 thinking 块也移除 + continue + } } // 跳过没有 type 但有 thinking 字段的块(untyped thinking blocks) if _, hasType := block["type"]; !hasType { @@ -390,9 +425,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel != claudeReq.Model { - log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name) - } // 获取 access_token if s.tokenProvider == nil { @@ -418,15 +450,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // 调试:记录转换后的请求体(仅记录前 2000 字符) - if bodyJSON, err := json.Marshal(geminiBody); err == nil { - truncated := string(bodyJSON) - if len(truncated) > 2000 { - truncated = truncated[:2000] + "..." - } - log.Printf("[Debug] Transformed Gemini request: %s", truncated) - } - // 构建上游 action action := "generateContent" if claudeReq.Stream { @@ -495,7 +518,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if err != nil { log.Printf("[Antigravity] Failed to transform stripped request: %v", err) // 降级失败,返回原始错误 - if s.shouldFailoverUpstreamError(resp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody) @@ -505,7 +528,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, retryReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, strippedBody) if err != nil { log.Printf("[Antigravity] Failed to create retry request: %v", err) - if s.shouldFailoverUpstreamError(resp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody) @@ -514,7 +537,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, retryResp, err := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) if err != nil { log.Printf("[Antigravity] Retry request failed: %v", err) - if s.shouldFailoverUpstreamError(resp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody) @@ -531,7 +554,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, log.Printf("[Antigravity] Retry also failed with status %d: %s", retryResp.StatusCode, string(retryRespBody)) s.handleUpstreamError(ctx, account, retryResp.StatusCode, retryResp.Header, retryRespBody) - if s.shouldFailoverUpstreamError(retryResp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, retryResp.StatusCode, retryRespBody) { return nil, &UpstreamFailoverError{StatusCode: retryResp.StatusCode} } return nil, s.writeMappedClaudeError(c, retryResp.StatusCode, retryRespBody) @@ -540,7 +563,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 不是 signature 错误,或者已经没有 thinking 块,直接返回错误 if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } @@ -594,8 +617,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } switch action { - case "generateContent", "streamGenerateContent", "countTokens": + case "generateContent", "streamGenerateContent": // ok + case "countTokens": + return nil, s.writeGoogleError(c, http.StatusNotImplemented, "countTokens is not supported") default: return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) } @@ -650,18 +675,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co sleepAntigravityBackoff(attempt) continue } - if action == "countTokens" { - estimated := estimateGeminiCountTokens(body) - c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, - }, nil - } return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } @@ -678,18 +691,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if resp.StatusCode == 429 { s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) } - if action == "countTokens" { - estimated := estimateGeminiCountTokens(body) - c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, - }, nil - } resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -712,20 +713,42 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - if action == "countTokens" { - estimated := estimateGeminiCountTokens(body) - c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) - return &ForwardResult{ - RequestID: requestID, - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, - }, nil + // Check if model fallback is enabled and this is a model not found error + if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && + isModelNotFoundError(resp.StatusCode, respBody) { + + fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) + + // Only retry if fallback model is different from current model + if fallbackModel != "" && fallbackModel != mappedModel { + log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", + mappedModel, fallbackModel, account.Name) + + // Close original response + _ = resp.Body.Close() + + // Rebuild request with fallback model + fallbackBody, err := s.wrapV1InternalRequest(projectID, fallbackModel, body) + if err == nil { + fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackBody) + if err == nil { + fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) + if err == nil && fallbackResp.StatusCode < 400 { + log.Printf("[Antigravity] Fallback succeeded with %s (account: %s)", fallbackModel, account.Name) + resp = fallbackResp + originalModel = fallbackModel // Update for billing + // Continue to normal response handling + goto handleSuccess + } else if fallbackResp != nil { + _ = fallbackResp.Body.Close() + } + } + } + log.Printf("[Antigravity] Fallback failed, returning original error") + } } - if s.shouldFailoverUpstreamError(resp.StatusCode) { + if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } @@ -739,6 +762,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } +handleSuccess: var usage *ClaudeUsage var firstTokenMs *int @@ -789,6 +813,15 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) } } +func (s *AntigravityGatewayService) shouldFailoverWithTempUnsched(ctx context.Context, account *Account, statusCode int, body []byte) bool { + if s.rateLimitService != nil { + if s.rateLimitService.HandleTempUnschedulable(ctx, account, statusCode, body) { + return true + } + } + return s.shouldFailoverUpstreamError(statusCode) +} + func sleepAntigravityBackoff(attempt int) { sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑 } @@ -899,7 +932,10 @@ func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Cont } // 解包 v1internal 响应 - unwrapped, _ := s.unwrapV1InternalResponse(respBody) + unwrapped := respBody + if inner, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && inner != nil { + unwrapped = inner + } var parsed map[string]any if json.Unmarshal(unwrapped, &parsed) == nil { @@ -973,6 +1009,8 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, statusStr = "RESOURCE_EXHAUSTED" case 500: statusStr = "INTERNAL" + case 501: + statusStr = "UNIMPLEMENTED" case 502, 503: statusStr = "UNAVAILABLE" } diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 1e37cdc2..39000e4f 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -104,28 +104,28 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "claude-opus-4-5-thinking", }, { - name: "系统映射 - claude-haiku-4 → gemini-3-flash", + name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5", requestedModel: "claude-haiku-4", accountMapping: nil, - expected: "gemini-3-flash", + expected: "claude-sonnet-4-5", }, { - name: "系统映射 - claude-haiku-4-5 → gemini-3-flash", + name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5", requestedModel: "claude-haiku-4-5", accountMapping: nil, - expected: "gemini-3-flash", + expected: "claude-sonnet-4-5", }, { - name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash", + name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5", requestedModel: "claude-3-haiku-20240307", accountMapping: nil, - expected: "gemini-3-flash", + expected: "claude-sonnet-4-5", }, { - name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash", + name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", requestedModel: "claude-haiku-4-5-20251001", accountMapping: nil, - expected: "gemini-3-flash", + expected: "claude-sonnet-4-5", }, { name: "系统映射 - claude-sonnet-4-5-20250929", diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go new file mode 100644 index 00000000..c0231e99 --- /dev/null +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -0,0 +1,134 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// AntigravityQuotaFetcher 从 Antigravity API 获取额度 +type AntigravityQuotaFetcher struct { + proxyRepo ProxyRepository +} + +// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher +func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher { + return &AntigravityQuotaFetcher{proxyRepo: proxyRepo} +} + +// CanFetch 检查是否可以获取此账户的额度 +func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool { + if f == nil || account == nil { + return false + } + if account.Platform != PlatformAntigravity { + return false + } + accessToken := account.GetCredential("access_token") + return accessToken != "" +} + +// FetchQuota 获取 Antigravity 账户额度信息 +func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) { + if f == nil { + return nil, fmt.Errorf("antigravity quota fetcher is nil") + } + if account == nil { + return nil, fmt.Errorf("account is nil") + } + accessToken := account.GetCredential("access_token") + projectID := account.GetCredential("project_id") + + // 如果没有 project_id,生成一个随机的 + if projectID == "" { + projectID = antigravity.GenerateMockProjectID() + } + + client := antigravity.NewClient(proxyURL) + + // 调用 API 获取配额 + modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) + if err != nil { + return nil, err + } + + // 转换为 UsageInfo + usageInfo := f.buildUsageInfo(modelsResp) + + return &QuotaResult{ + UsageInfo: usageInfo, + Raw: modelsRaw, + }, nil +} + +// buildUsageInfo 将 API 响应转换为 UsageInfo +func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { + now := time.Now() + info := &UsageInfo{ + UpdatedAt: &now, + AntigravityQuota: make(map[string]*AntigravityModelQuota), + } + + if modelsResp == nil { + return info + } + + // 遍历所有模型,填充 AntigravityQuota + for modelName, modelInfo := range modelsResp.Models { + if modelInfo.QuotaInfo == nil { + continue + } + + // remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比 + utilization := clampInt(int((1.0-modelInfo.QuotaInfo.RemainingFraction)*100), 0, 100) + + info.AntigravityQuota[modelName] = &AntigravityModelQuota{ + Utilization: utilization, + ResetTime: modelInfo.QuotaInfo.ResetTime, + } + } + + // 同时设置 FiveHour 用于兼容展示(取主要模型) + priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"} + for _, modelName := range priorityModels { + if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil { + utilization := clampFloat64((1.0-modelInfo.QuotaInfo.RemainingFraction)*100, 0, 100) + progress := &UsageProgress{ + Utilization: utilization, + } + if modelInfo.QuotaInfo.ResetTime != "" { + if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil { + progress.ResetsAt = &resetTime + progress.RemainingSeconds = remainingSecondsUntil(resetTime) + } + } + info.FiveHour = progress + break + } + } + + return info +} + +// GetProxyURL 获取账户的代理 URL +func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) (string, error) { + if f == nil { + return "", fmt.Errorf("antigravity quota fetcher is nil") + } + if account == nil { + return "", fmt.Errorf("account is nil") + } + if account.ProxyID == nil || f.proxyRepo == nil { + return "", nil + } + proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID) + if err != nil { + return "", err + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} diff --git a/backend/internal/service/antigravity_quota_refresher.go b/backend/internal/service/antigravity_quota_refresher.go deleted file mode 100644 index c4b11d73..00000000 --- a/backend/internal/service/antigravity_quota_refresher.go +++ /dev/null @@ -1,222 +0,0 @@ -package service - -import ( - "context" - "log" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息 -type AntigravityQuotaRefresher struct { - accountRepo AccountRepository - proxyRepo ProxyRepository - cfg *config.TokenRefreshConfig - - stopCh chan struct{} - wg sync.WaitGroup -} - -// NewAntigravityQuotaRefresher 创建配额刷新器 -func NewAntigravityQuotaRefresher( - accountRepo AccountRepository, - proxyRepo ProxyRepository, - _ *AntigravityOAuthService, - cfg *config.Config, -) *AntigravityQuotaRefresher { - return &AntigravityQuotaRefresher{ - accountRepo: accountRepo, - proxyRepo: proxyRepo, - cfg: &cfg.TokenRefresh, - stopCh: make(chan struct{}), - } -} - -// Start 启动后台配额刷新服务 -func (r *AntigravityQuotaRefresher) Start() { - if !r.cfg.Enabled { - log.Println("[AntigravityQuota] Service disabled by configuration") - return - } - - r.wg.Add(1) - go r.refreshLoop() - - log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes) -} - -// Stop 停止服务 -func (r *AntigravityQuotaRefresher) Stop() { - close(r.stopCh) - r.wg.Wait() - log.Println("[AntigravityQuota] Service stopped") -} - -// refreshLoop 刷新循环 -func (r *AntigravityQuotaRefresher) refreshLoop() { - defer r.wg.Done() - - checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute - if checkInterval < time.Minute { - checkInterval = 5 * time.Minute - } - - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() - - // 启动时立即执行一次 - r.processRefresh() - - for { - select { - case <-ticker.C: - r.processRefresh() - case <-r.stopCh: - return - } - } -} - -// processRefresh 执行一次刷新 -func (r *AntigravityQuotaRefresher) processRefresh() { - ctx := context.Background() - - // 查询所有 active 的账户,然后过滤 antigravity 平台 - allAccounts, err := r.accountRepo.ListActive(ctx) - if err != nil { - log.Printf("[AntigravityQuota] Failed to list accounts: %v", err) - return - } - - // 过滤 antigravity 平台账户 - var accounts []Account - for _, acc := range allAccounts { - if acc.Platform == PlatformAntigravity { - accounts = append(accounts, acc) - } - } - - if len(accounts) == 0 { - return - } - - refreshed, failed := 0, 0 - - for i := range accounts { - account := &accounts[i] - - if err := r.refreshAccountQuota(ctx, account); err != nil { - log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err) - failed++ - } else { - refreshed++ - } - } - - log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d", - len(accounts), refreshed, failed) -} - -// refreshAccountQuota 刷新单个账户的配额 -func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error { - accessToken := account.GetCredential("access_token") - projectID := account.GetCredential("project_id") - - if accessToken == "" { - return nil // 没有 access_token,跳过 - } - - // token 过期则跳过,由 TokenRefreshService 负责刷新 - if r.isTokenExpired(account) { - return nil - } - - // 获取代理 URL - var proxyURL string - if account.ProxyID != nil { - proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID) - if err == nil && proxy != nil { - proxyURL = proxy.URL() - } - } - - client := antigravity.NewClient(proxyURL) - - if account.Extra == nil { - account.Extra = make(map[string]any) - } - - // 获取账户信息(tier、project_id 等) - loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken) - if loadRaw != nil { - account.Extra["load_code_assist"] = loadRaw - } - if loadResp != nil { - // 尝试从 API 获取 project_id - if projectID == "" && loadResp.CloudAICompanionProject != "" { - projectID = loadResp.CloudAICompanionProject - account.Credentials["project_id"] = projectID - } - } - - // 如果仍然没有 project_id,随机生成一个并保存 - if projectID == "" { - projectID = antigravity.GenerateMockProjectID() - account.Credentials["project_id"] = projectID - log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID) - } - - // 调用 API 获取配额 - modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) - if err != nil { - return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息 - } - - // 保存完整的配额响应 - if modelsRaw != nil { - account.Extra["available_models"] = modelsRaw - } - - // 解析配额数据为前端使用的格式 - r.updateAccountQuota(account, modelsResp) - - account.Extra["last_refresh"] = time.Now().Format(time.RFC3339) - - // 保存到数据库 - return r.accountRepo.Update(ctx, account) -} - -// isTokenExpired 检查 token 是否过期 -func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool { - expiresAt := account.GetCredentialAsTime("expires_at") - if expiresAt == nil { - return false - } - - // 提前 5 分钟认为过期 - return time.Now().Add(5 * time.Minute).After(*expiresAt) -} - -// updateAccountQuota 更新账户的配额信息(前端使用的格式) -func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) { - quota := make(map[string]any) - - for modelName, modelInfo := range modelsResp.Models { - if modelInfo.QuotaInfo == nil { - continue - } - - // 转换 remainingFraction (0.0-1.0) 为百分比 (0-100) - remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100) - - quota[modelName] = map[string]any{ - "remaining": remaining, - "reset_time": modelInfo.QuotaInfo.ResetTime, - } - } - - account.Extra["quota"] = quota -} diff --git a/backend/internal/service/quota_fetcher.go b/backend/internal/service/quota_fetcher.go new file mode 100644 index 00000000..5c376d70 --- /dev/null +++ b/backend/internal/service/quota_fetcher.go @@ -0,0 +1,21 @@ +package service + +import ( + "context" +) + +// QuotaFetcher 额度获取接口,各平台实现此接口 +type QuotaFetcher interface { + // CanFetch 检查是否可以获取此账户的额度 + CanFetch(account *Account) bool + // GetProxyURL 获取账户的代理 URL(如果没有代理则返回空字符串) + GetProxyURL(ctx context.Context, account *Account) (string, error) + // FetchQuota 获取账户额度信息 + FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) +} + +// QuotaResult 额度获取结果 +type QuotaResult struct { + UsageInfo *UsageInfo // 转换后的使用信息 + Raw map[string]any // 原始响应,可存入 account.Extra +}