diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 8596b8ba..ff6ab4e6 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -70,7 +70,6 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, - antigravityQuota *service.AntigravityQuotaRefresher, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -113,10 +112,6 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, - {"AntigravityQuotaRefresher", func() error { - antigravityQuota.Stop() - return nil - }}, {"Redis", func() error { return rdb.Close() }}, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 9eab00e7..53fc1278 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -90,7 +90,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService) claudeUsageFetcher := repository.NewClaudeUsageFetcher() - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService) + antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) + usageCache := service.NewUsageCache() + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) @@ -145,8 +147,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) - antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig) - v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) + v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -179,7 +180,6 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, - antigravityQuota *service.AntigravityQuotaRefresher, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -221,10 +221,6 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, - {"AntigravityQuotaRefresher", func() error { - antigravityQuota.Stop() - return nil - }}, {"Redis", func() error { return rdb.Close() }}, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index e202eb7f..21a3af56 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -588,8 +588,20 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format - errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // Send error event in SSE format with proper JSON marshaling + errorData := map[string]any{ + "type": "error", + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -740,8 +752,27 @@ func sendMockWarmupStream(c *gin.Context, model string) { c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") + // Build message_start event with proper JSON marshaling + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": "msg_mock_warmup", + "type": "message", + "role": "assistant", + "model": model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]int{ + "input_tokens": 10, + "output_tokens": 0, + }, + }, + } + messageStartJSON, _ := json.Marshal(messageStart) + events := []string{ - `event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`, + `event: message_start` + "\n" + `data: ` + string(messageStartJSON), `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, `event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`, `event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`, diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4e049dbb..9d2e4a9d 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -144,6 +144,21 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() + // Try immediate acquire first (avoid unnecessary wait) + var result *service.AcquireResult + var err error + if slotType == "user" { + result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } else { + result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + } + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } + // Determine if ping is needed (streaming + ping format defined) needPing := isStream && h.pingFormat != "" diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 8e3e3885..67f6c3e7 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -143,9 +143,10 @@ type GeminiCandidate struct { // GeminiUsageMetadata Gemini 用量元数据 type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount,omitempty"` - CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` - TotalTokenCount int `json:"totalTokenCount,omitempty"` + PromptTokenCount int `json:"promptTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + TotalTokenCount int `json:"totalTokenCount,omitempty"` } // DefaultSafetySettings 默认安全设置(关闭所有过滤) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 83b87a32..d662be0e 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) + // 检测是否启用 thinking + isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") - // 检测是否启用 thinking - requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), - // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 - isThinkingEnabled := requestedThinkingEnabled && allowDummyThought - // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - reqForGen := claudeReq - if requestedThinkingEnabled && !allowDummyThought { - log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) - // shallow copy to avoid mutating caller's request - clone := *claudeReq - clone.Thinking = nil - reqForGen = &clone - } - generationConfig := buildGenerationConfig(reqForGen) + generationConfig := buildGenerationConfig(claudeReq) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -183,34 +172,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" -// isValidThoughtSignature 验证 thought signature 是否有效 -// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 -func isValidThoughtSignature(signature string) bool { - // 空字符串无效 - if signature == "" { - return false - } - - // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) - // 参考 Claude API 文档和实际观察到的有效 signature - if len(signature) < 40 { - log.Printf("[Debug] Signature too short: len=%d", len(signature)) - return false - } - - // 检查是否是有效的 base64 字符 - // base64 字符集: A-Z, a-z, 0-9, +, /, = - for i, c := range signature { - if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && - (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { - log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) - return false - } - } - - return true -} - // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -239,30 +200,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - if allowDummyThought { - // Gemini 模型可以使用 dummy signature - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: dummyThoughtSignature, - }) + part := GeminiPart{ + Text: block.Thinking, + Thought: true, + } + // 保留原有 signature(Claude 模型需要有效的 signature) + if block.Signature != "" { + part.ThoughtSignature = block.Signature + } else if !allowDummyThought { + // Claude 模型需要有效 signature,跳过无 signature 的 thinking block + log.Printf("Warning: skipping thinking block without signature for Claude model") continue + } else { + // Gemini 模型使用 dummy signature + part.ThoughtSignature = dummyThoughtSignature } - - // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 - signature := strings.TrimSpace(block.Signature) - if signature == "" || signature == dummyThoughtSignature { - log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") - continue - } - if !isValidThoughtSignature(signature) { - log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) - } - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: signature, - }) + parts = append(parts, part) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -433,7 +386,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for i, tool := range tools { + for _, tool := range tools { // 跳过无效工具名称 if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") @@ -452,10 +405,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { description = tool.Custom.Description inputSchema = tool.Custom.InputSchema - // 调试日志:记录 custom 工具的 schema - if schemaJSON, err := json.Marshal(inputSchema); err == nil { - log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) - } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -472,11 +421,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } - // 调试日志:记录清理后的 schema - if paramsJSON, err := json.Marshal(params); err == nil { - log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) - } - funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -631,11 +575,9 @@ func cleanSchemaValue(value any) any { if k == "additionalProperties" { if boolVal, ok := val.(bool); ok { result[k] = boolVal - log.Printf("[Debug] additionalProperties is bool: %v", boolVal) } else { // 如果是 schema 对象,转换为 false(更安全的默认值) result[k] = false - log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) } continue } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index ba07893f..56eebad0 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "mcp_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "MCP tool description", InputSchema: map[string]any{ "type": "object", @@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "custom_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Custom tool", InputSchema: map[string]any{"type": "object"}, }, @@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "invalid_custom", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Invalid", // InputSchema 为 nil }, diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 799de694..cd7f5f80 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -232,10 +232,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon stopReason = "max_tokens" } + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, + // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 usage := ClaudeUsage{} if geminiResp.UsageMetadata != nil { - usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount + cached := geminiResp.UsageMetadata.CachedContentTokenCount + usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + usage.CacheReadInputTokens = cached } // 生成响应 ID diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index c5d954f5..9fe68a11 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -29,8 +29,9 @@ type StreamingProcessor struct { originalModel string // 累计 usage - inputTokens int - outputTokens int + inputTokens int + outputTokens int + cacheReadTokens int } // NewStreamingProcessor 创建流式响应处理器 @@ -76,9 +77,13 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { } // 更新 usage + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, + // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 if geminiResp.UsageMetadata != nil { - p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount + cached := geminiResp.UsageMetadata.CachedContentTokenCount + p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + p.cacheReadTokens = cached } // 处理 parts @@ -108,8 +113,9 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { } usage := &ClaudeUsage{ - InputTokens: p.inputTokens, - OutputTokens: p.outputTokens, + InputTokens: p.inputTokens, + OutputTokens: p.outputTokens, + CacheReadInputTokens: p.cacheReadTokens, } return result.Bytes(), usage @@ -123,8 +129,10 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte usage := ClaudeUsage{} if v1Resp.Response.UsageMetadata != nil { - usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount + cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount + usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + usage.CacheReadInputTokens = cached } responseID := v1Resp.ResponseID @@ -418,8 +426,9 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { } usage := ClaudeUsage{ - InputTokens: p.inputTokens, - OutputTokens: p.outputTokens, + InputTokens: p.inputTokens, + OutputTokens: p.outputTokens, + CacheReadInputTokens: p.cacheReadTokens, } deltaEvent := map[string]any{ diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 35296497..95370f51 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -151,11 +151,17 @@ var ( return 1 `) - // getAccountsLoadBatchScript - batch load query (read-only) - // ARGV[1] = slot TTL (seconds, retained for compatibility) + // getAccountsLoadBatchScript - batch load query with expired slot cleanup + // ARGV[1] = slot TTL (seconds) // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... getAccountsLoadBatchScript = redis.NewScript(` local result = {} + local slotTTL = tonumber(ARGV[1]) + + -- Get current server time + local timeResult = redis.call('TIME') + local nowSeconds = tonumber(timeResult[1]) + local cutoffTime = nowSeconds - slotTTL local i = 2 while i <= #ARGV do @@ -163,6 +169,9 @@ var ( local maxConcurrency = tonumber(ARGV[i + 1]) local slotKey = 'concurrency:account:' .. accountID + + -- Clean up expired slots before counting + redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) local currentConcurrency = redis.call('ZCARD', slotKey) local waitKey = 'wait:account:' .. accountID diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index dfceac07..c4220c0c 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -69,13 +69,29 @@ type windowStatsCache struct { timestamp time.Time } -var ( - apiCacheMap = sync.Map{} // 缓存 API 响应 - windowStatsCacheMap = sync.Map{} // 缓存窗口统计 +// antigravityUsageCache 缓存 Antigravity 额度数据 +type antigravityUsageCache struct { + usageInfo *UsageInfo + timestamp time.Time +} + +const ( apiCacheTTL = 10 * time.Minute windowStatsCacheTTL = 1 * time.Minute ) +// UsageCache 封装账户使用量相关的缓存 +type UsageCache struct { + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache +} + +// NewUsageCache 创建 UsageCache 实例 +func NewUsageCache() *UsageCache { + return &UsageCache{} +} + // WindowStats 窗口期统计 type WindowStats struct { Requests int64 `json:"requests"` @@ -91,6 +107,12 @@ type UsageProgress struct { WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) } +// AntigravityModelQuota Antigravity 单个模型的配额信息 +type AntigravityModelQuota struct { + Utilization int `json:"utilization"` // 使用率 0-100 + ResetTime string `json:"reset_time"` // 重置时间 ISO8601 +} + // UsageInfo 账号使用量信息 type UsageInfo struct { UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 @@ -99,6 +121,9 @@ type UsageInfo struct { SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 + + // Antigravity 多模型配额 + AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -124,19 +149,30 @@ type ClaudeUsageFetcher interface { // AccountUsageService 账号使用量查询服务 type AccountUsageService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - usageFetcher ClaudeUsageFetcher - geminiQuotaService *GeminiQuotaService + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageFetcher ClaudeUsageFetcher + geminiQuotaService *GeminiQuotaService + antigravityQuotaFetcher *AntigravityQuotaFetcher + cache *UsageCache } // NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService { +func NewAccountUsageService( + accountRepo AccountRepository, + usageLogRepo UsageLogRepository, + usageFetcher ClaudeUsageFetcher, + geminiQuotaService *GeminiQuotaService, + antigravityQuotaFetcher *AntigravityQuotaFetcher, + cache *UsageCache, +) *AccountUsageService { return &AccountUsageService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - usageFetcher: usageFetcher, - geminiQuotaService: geminiQuotaService, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageFetcher: usageFetcher, + geminiQuotaService: geminiQuotaService, + antigravityQuotaFetcher: antigravityQuotaFetcher, + cache: cache, } } @@ -154,12 +190,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return s.getGeminiUsage(ctx, account) } + // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 + if account.Platform == PlatformAntigravity { + return s.getAntigravityUsage(ctx, account) + } + // 只有oauth类型账号可以通过API获取usage(有profile scope) if account.CanGetUsage() { var apiResp *ClaudeUsageResponse // 1. 检查 API 缓存(10 分钟) - if cached, ok := apiCacheMap.Load(accountID); ok { + if cached, ok := s.cache.apiCache.Load(accountID); ok { if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { apiResp = cache.response } @@ -172,7 +213,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, err } // 缓存 API 响应 - apiCacheMap.Store(accountID, &apiUsageCache{ + s.cache.apiCache.Store(accountID, &apiUsageCache{ response: apiResp, timestamp: time.Now(), }) @@ -230,6 +271,43 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou return usage, nil } +// getAntigravityUsage 获取 Antigravity 账户额度 +func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + + // 1. 检查缓存(10 分钟) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { + // 重新计算 RemainingSeconds + usage := cache.usageInfo + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + } + return usage, nil + } + } + + // 2. 获取代理 URL + proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + + // 3. 调用 API 获取额度 + result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) + if err != nil { + return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) + } + + // 4. 缓存结果 + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: result.UsageInfo, + timestamp: time.Now(), + }) + + return result.UsageInfo, nil +} + // addWindowStats 为 usage 数据添加窗口期统计 // 使用独立缓存(1 分钟),与 API 缓存分离 func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { @@ -241,7 +319,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou // 检查窗口统计缓存(1 分钟) var windowStats *WindowStats - if cached, ok := windowStatsCacheMap.Load(account.ID); ok { + if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok { if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { windowStats = cache.stats } @@ -269,7 +347,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou } // 缓存窗口统计(1 分钟) - windowStatsCacheMap.Store(account.ID, &windowStatsCache{ + s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{ stats: windowStats, timestamp: time.Now(), }) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3093dd9a..e4843f1b 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -322,9 +322,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel != claudeReq.Model { - log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name) - } // 获取 access_token if s.tokenProvider == nil { @@ -350,15 +347,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // 调试:记录转换后的请求体(仅记录前 2000 字符) - if bodyJSON, err := json.Marshal(geminiBody); err == nil { - truncated := string(bodyJSON) - if len(truncated) > 2000 { - truncated = truncated[:2000] + "..." - } - log.Printf("[Debug] Transformed Gemini request: %s", truncated) - } - // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go new file mode 100644 index 00000000..c9024e33 --- /dev/null +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -0,0 +1,111 @@ +package service + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// AntigravityQuotaFetcher 从 Antigravity API 获取额度 +type AntigravityQuotaFetcher struct { + proxyRepo ProxyRepository +} + +// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher +func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher { + return &AntigravityQuotaFetcher{proxyRepo: proxyRepo} +} + +// CanFetch 检查是否可以获取此账户的额度 +func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool { + if account.Platform != PlatformAntigravity { + return false + } + accessToken := account.GetCredential("access_token") + return accessToken != "" +} + +// FetchQuota 获取 Antigravity 账户额度信息 +func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) { + accessToken := account.GetCredential("access_token") + projectID := account.GetCredential("project_id") + + // 如果没有 project_id,生成一个随机的 + if projectID == "" { + projectID = antigravity.GenerateMockProjectID() + } + + client := antigravity.NewClient(proxyURL) + + // 调用 API 获取配额 + modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) + if err != nil { + return nil, err + } + + // 转换为 UsageInfo + usageInfo := f.buildUsageInfo(modelsResp) + + return &QuotaResult{ + UsageInfo: usageInfo, + Raw: modelsRaw, + }, nil +} + +// buildUsageInfo 将 API 响应转换为 UsageInfo +func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { + now := time.Now() + info := &UsageInfo{ + UpdatedAt: &now, + AntigravityQuota: make(map[string]*AntigravityModelQuota), + } + + // 遍历所有模型,填充 AntigravityQuota + for modelName, modelInfo := range modelsResp.Models { + if modelInfo.QuotaInfo == nil { + continue + } + + // remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比 + utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100) + + info.AntigravityQuota[modelName] = &AntigravityModelQuota{ + Utilization: utilization, + ResetTime: modelInfo.QuotaInfo.ResetTime, + } + } + + // 同时设置 FiveHour 用于兼容展示(取主要模型) + priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"} + for _, modelName := range priorityModels { + if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil { + utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100 + progress := &UsageProgress{ + Utilization: utilization, + } + if modelInfo.QuotaInfo.ResetTime != "" { + if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil { + progress.ResetsAt = &resetTime + progress.RemainingSeconds = int(time.Until(resetTime).Seconds()) + } + } + info.FiveHour = progress + break + } + } + + return info +} + +// GetProxyURL 获取账户的代理 URL +func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string { + if account.ProxyID == nil || f.proxyRepo == nil { + return "" + } + proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID) + if err != nil || proxy == nil { + return "" + } + return proxy.URL() +} diff --git a/backend/internal/service/antigravity_quota_refresher.go b/backend/internal/service/antigravity_quota_refresher.go deleted file mode 100644 index c4b11d73..00000000 --- a/backend/internal/service/antigravity_quota_refresher.go +++ /dev/null @@ -1,222 +0,0 @@ -package service - -import ( - "context" - "log" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息 -type AntigravityQuotaRefresher struct { - accountRepo AccountRepository - proxyRepo ProxyRepository - cfg *config.TokenRefreshConfig - - stopCh chan struct{} - wg sync.WaitGroup -} - -// NewAntigravityQuotaRefresher 创建配额刷新器 -func NewAntigravityQuotaRefresher( - accountRepo AccountRepository, - proxyRepo ProxyRepository, - _ *AntigravityOAuthService, - cfg *config.Config, -) *AntigravityQuotaRefresher { - return &AntigravityQuotaRefresher{ - accountRepo: accountRepo, - proxyRepo: proxyRepo, - cfg: &cfg.TokenRefresh, - stopCh: make(chan struct{}), - } -} - -// Start 启动后台配额刷新服务 -func (r *AntigravityQuotaRefresher) Start() { - if !r.cfg.Enabled { - log.Println("[AntigravityQuota] Service disabled by configuration") - return - } - - r.wg.Add(1) - go r.refreshLoop() - - log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes) -} - -// Stop 停止服务 -func (r *AntigravityQuotaRefresher) Stop() { - close(r.stopCh) - r.wg.Wait() - log.Println("[AntigravityQuota] Service stopped") -} - -// refreshLoop 刷新循环 -func (r *AntigravityQuotaRefresher) refreshLoop() { - defer r.wg.Done() - - checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute - if checkInterval < time.Minute { - checkInterval = 5 * time.Minute - } - - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() - - // 启动时立即执行一次 - r.processRefresh() - - for { - select { - case <-ticker.C: - r.processRefresh() - case <-r.stopCh: - return - } - } -} - -// processRefresh 执行一次刷新 -func (r *AntigravityQuotaRefresher) processRefresh() { - ctx := context.Background() - - // 查询所有 active 的账户,然后过滤 antigravity 平台 - allAccounts, err := r.accountRepo.ListActive(ctx) - if err != nil { - log.Printf("[AntigravityQuota] Failed to list accounts: %v", err) - return - } - - // 过滤 antigravity 平台账户 - var accounts []Account - for _, acc := range allAccounts { - if acc.Platform == PlatformAntigravity { - accounts = append(accounts, acc) - } - } - - if len(accounts) == 0 { - return - } - - refreshed, failed := 0, 0 - - for i := range accounts { - account := &accounts[i] - - if err := r.refreshAccountQuota(ctx, account); err != nil { - log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err) - failed++ - } else { - refreshed++ - } - } - - log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d", - len(accounts), refreshed, failed) -} - -// refreshAccountQuota 刷新单个账户的配额 -func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error { - accessToken := account.GetCredential("access_token") - projectID := account.GetCredential("project_id") - - if accessToken == "" { - return nil // 没有 access_token,跳过 - } - - // token 过期则跳过,由 TokenRefreshService 负责刷新 - if r.isTokenExpired(account) { - return nil - } - - // 获取代理 URL - var proxyURL string - if account.ProxyID != nil { - proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID) - if err == nil && proxy != nil { - proxyURL = proxy.URL() - } - } - - client := antigravity.NewClient(proxyURL) - - if account.Extra == nil { - account.Extra = make(map[string]any) - } - - // 获取账户信息(tier、project_id 等) - loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken) - if loadRaw != nil { - account.Extra["load_code_assist"] = loadRaw - } - if loadResp != nil { - // 尝试从 API 获取 project_id - if projectID == "" && loadResp.CloudAICompanionProject != "" { - projectID = loadResp.CloudAICompanionProject - account.Credentials["project_id"] = projectID - } - } - - // 如果仍然没有 project_id,随机生成一个并保存 - if projectID == "" { - projectID = antigravity.GenerateMockProjectID() - account.Credentials["project_id"] = projectID - log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID) - } - - // 调用 API 获取配额 - modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) - if err != nil { - return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息 - } - - // 保存完整的配额响应 - if modelsRaw != nil { - account.Extra["available_models"] = modelsRaw - } - - // 解析配额数据为前端使用的格式 - r.updateAccountQuota(account, modelsResp) - - account.Extra["last_refresh"] = time.Now().Format(time.RFC3339) - - // 保存到数据库 - return r.accountRepo.Update(ctx, account) -} - -// isTokenExpired 检查 token 是否过期 -func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool { - expiresAt := account.GetCredentialAsTime("expires_at") - if expiresAt == nil { - return false - } - - // 提前 5 分钟认为过期 - return time.Now().Add(5 * time.Minute).After(*expiresAt) -} - -// updateAccountQuota 更新账户的配额信息(前端使用的格式) -func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) { - quota := make(map[string]any) - - for modelName, modelInfo := range modelsResp.Models { - if modelInfo.QuotaInfo == nil { - continue - } - - // 转换 remainingFraction (0.0-1.0) 为百分比 (0-100) - remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100) - - quota[modelName] = map[string]any{ - "remaining": remaining, - "reset_time": modelInfo.QuotaInfo.ResetTime, - } - } - - account.Extra["quota"] = quota -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3e6f876c..8f1bf756 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -206,7 +206,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // BindStickySession sets session -> account binding with standard TTL. func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { + if sessionHash == "" || accountID <= 0 || s.cache == nil { return nil } return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) @@ -431,7 +431,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) diff --git a/backend/internal/service/quota_fetcher.go b/backend/internal/service/quota_fetcher.go new file mode 100644 index 00000000..40d8572c --- /dev/null +++ b/backend/internal/service/quota_fetcher.go @@ -0,0 +1,19 @@ +package service + +import ( + "context" +) + +// QuotaFetcher 额度获取接口,各平台实现此接口 +type QuotaFetcher interface { + // CanFetch 检查是否可以获取此账户的额度 + CanFetch(account *Account) bool + // FetchQuota 获取账户额度信息 + FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) +} + +// QuotaResult 额度获取结果 +type QuotaResult struct { + UsageInfo *UsageInfo // 转换后的使用信息 + Raw map[string]any // 原始响应,可存入 account.Extra +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 7971f041..f52c2a4a 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService { return svc } -// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher -func ProvideAntigravityQuotaRefresher( - accountRepo AccountRepository, - proxyRepo ProxyRepository, - oauthSvc *AntigravityOAuthService, - cfg *config.Config, -) *AntigravityQuotaRefresher { - svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg) - svc.Start() - return svc -} - // ProvideDeferredService creates and starts DeferredService func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService { svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second) @@ -124,6 +112,7 @@ var ProviderSet = wire.NewSet( ProvideTokenRefreshService, ProvideTimingWheelService, ProvideDeferredService, - ProvideAntigravityQuotaRefresher, + NewAntigravityQuotaFetcher, NewUserAttributeService, + NewUsageCache, ) diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 8dfb9f38..b0bc6c32 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -93,7 +93,7 @@