diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f37d696b..79827b26 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { if err != nil { return nil, err } - sqlDB, err := infrastructure.ProvideSQLDB(client) + db, err := infrastructure.ProvideSQLDB(client) if err != nil { return nil, err } - userRepository := repository.NewUserRepository(client, sqlDB) + userRepository := repository.NewUserRepository(client, db) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) redisClient := infrastructure.ProvideRedis(configConfig) @@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { authHandler := handler.NewAuthHandler(configConfig, authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, sqlDB) + groupRepository := repository.NewGroupRepository(client, db) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyCache := repository.NewApiKeyCache(redisClient) apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) - usageLogRepository := repository.NewUsageLogRepository(client, sqlDB) + usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) @@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) dashboardHandler := admin.NewDashboardHandler(dashboardService) - accountRepository := repository.NewAccountRepository(client, sqlDB) - proxyRepository := repository.NewProxyRepository(client, sqlDB) + accountRepository := repository.NewAccountRepository(client, db) + proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber() adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminUserHandler := admin.NewUserHandler(adminService) @@ -93,8 +93,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + gatewayCache := repository.NewGatewayCache(redisClient) + antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) + antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) - accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) + accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.NewConcurrencyCache(redisClient) concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) @@ -102,7 +106,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) - antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) proxyHandler := admin.NewProxyHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService) @@ -115,7 +118,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) - gatewayCache := repository.NewGatewayCache(redisClient) pricingRemoteClient := repository.NewPricingRemoteClient() pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) if err != nil { @@ -127,8 +129,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 5fa2f4e1..ac938f8c 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -918,6 +918,37 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } + // Handle Antigravity accounts: return Claude + Gemini models + if account.Platform == service.PlatformAntigravity { + // Antigravity 支持 Claude 和部分 Gemini 模型 + type UnifiedModel struct { + ID string `json:"id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + } + + var models []UnifiedModel + + // 添加 Claude 模型 + for _, m := range claude.DefaultModels { + models = append(models, UnifiedModel{ + ID: m.ID, + Type: m.Type, + DisplayName: m.DisplayName, + }) + } + + // 添加 Gemini 3 系列模型用于测试 + geminiTestModels := []UnifiedModel{ + {ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"}, + {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"}, + } + models = append(models, geminiTestModels...) + + response.Success(c, models) + return + } + // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 6296f2fe..3223eb18 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -44,11 +44,12 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { - accountRepo AccountRepository - oauthService *OAuthService - openaiOAuthService *OpenAIOAuthService - geminiTokenProvider *GeminiTokenProvider - httpUpstream HTTPUpstream + accountRepo AccountRepository + oauthService *OAuthService + openaiOAuthService *OpenAIOAuthService + geminiTokenProvider *GeminiTokenProvider + antigravityGatewayService *AntigravityGatewayService + httpUpstream HTTPUpstream } // NewAccountTestService creates a new AccountTestService @@ -57,14 +58,16 @@ func NewAccountTestService( oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiTokenProvider *GeminiTokenProvider, + antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, ) *AccountTestService { return &AccountTestService{ - accountRepo: accountRepo, - oauthService: oauthService, - openaiOAuthService: openaiOAuthService, - geminiTokenProvider: geminiTokenProvider, - httpUpstream: httpUpstream, + accountRepo: accountRepo, + oauthService: oauthService, + openaiOAuthService: openaiOAuthService, + geminiTokenProvider: geminiTokenProvider, + antigravityGatewayService: antigravityGatewayService, + httpUpstream: httpUpstream, } } @@ -141,6 +144,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.testGeminiAccountConnection(c, account, modelID) } + if account.Platform == PlatformAntigravity { + return s.testAntigravityAccountConnection(c, account, modelID) + } + return s.testClaudeAccountConnection(c, account, modelID) } @@ -457,6 +464,46 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +// testAntigravityAccountConnection tests an Antigravity account's connection +// 支持 Claude 和 Gemini 两种协议,使用非流式请求 +func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { + ctx := c.Request.Context() + + // 默认模型:Claude 使用 claude-sonnet-4-5,Gemini 使用 gemini-3-pro-preview + testModelID := modelID + if testModelID == "" { + testModelID = "claude-sonnet-4-5" + } + + if s.antigravityGatewayService == nil { + return s.sendErrorAndEnd(c, "Antigravity gateway service not configured") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + // 调用 AntigravityGatewayService.TestConnection(复用协议转换逻辑) + result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID) + if err != nil { + return s.sendErrorAndEnd(c, err.Error()) + } + + // 发送响应内容 + if result.Text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: result.Text}) + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + // buildGeminiAPIKeyRequest builds request for Gemini API Key accounts func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { apiKey := account.GetCredential("api_key") @@ -514,7 +561,12 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun return req, nil } - // Wrap payload in Code Assist format + // Code Assist mode (with project_id) + return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload) +} + +// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity) +func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) { var inner map[string]any if err := json.Unmarshal(payload, &inner); err != nil { return nil, err diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 18a67fdf..670f53ee 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -130,6 +130,165 @@ func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool return false } +// TestConnectionResult 测试连接结果 +type TestConnectionResult struct { + Text string // 响应文本 + MappedModel string // 实际使用的模型 +} + +// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) +// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 +func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + // 获取 token + if s.tokenProvider == nil { + return nil, errors.New("antigravity token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("获取 access_token 失败: %w", err) + } + + // 获取 project_id + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + return nil, errors.New("project_id not found in credentials") + } + + // 模型映射 + mappedModel := s.getMappedModel(account, modelID) + + // 构建请求体 + var requestBody []byte + if strings.HasPrefix(modelID, "gemini-") { + // Gemini 模型:直接使用 Gemini 格式 + requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel) + } else { + // Claude 模型:使用协议转换 + requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel) + } + if err != nil { + return nil, fmt.Errorf("构建请求失败: %w", err) + } + + // 构建 HTTP 请求(非流式) + fullURL := fmt.Sprintf("%s/v1internal:generateContent", antigravity.BaseURL) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", antigravity.UserAgent) + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 读取响应 + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 解包 v1internal 响应 + unwrapped, err := s.unwrapV1InternalResponse(respBody) + if err != nil { + return nil, fmt.Errorf("解包响应失败: %w", err) + } + + // 提取响应文本 + text := extractGeminiResponseText(unwrapped) + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil +} + +// buildGeminiTestRequest 构建 Gemini 格式测试请求 +func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": "hi"}, + }, + }, + }, + } + payloadBytes, _ := json.Marshal(payload) + return s.wrapV1InternalRequest(projectID, model, payloadBytes) +} + +// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式 +func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) { + claudeReq := &antigravity.ClaudeRequest{ + Model: mappedModel, + Messages: []antigravity.ClaudeMessage{ + { + Role: "user", + Content: json.RawMessage(`"hi"`), + }, + }, + MaxTokens: 1024, + Stream: false, + } + return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel) +} + +// extractGeminiResponseText 从 Gemini 响应中提取文本 +func extractGeminiResponseText(respBody []byte) string { + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "" + } + + candidates, ok := resp["candidates"].([]any) + if !ok || len(candidates) == 0 { + return "" + } + + candidate, ok := candidates[0].(map[string]any) + if !ok { + return "" + } + + content, ok := candidate["content"].(map[string]any) + if !ok { + return "" + } + + parts, ok := content["parts"].([]any) + if !ok { + return "" + } + + var texts []string + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok && text != "" { + texts = append(texts, text) + } + } + } + + return strings.Join(texts, "") +} + // wrapV1InternalRequest 包装请求为 v1internal 格式 func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) { var request any