diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index b27d0535..438864be 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) - geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService) + antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) + geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) @@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) httpServer := server.ProvideHTTPServer(configConfig, engine) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0a4f05e..9c77bafa 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -21,27 +21,30 @@ import ( // GatewayHandler handles API gateway requests type GatewayHandler struct { - gatewayService *service.GatewayService - geminiCompatService *service.GeminiMessagesCompatService - userService *service.UserService - billingCacheService *service.BillingCacheService - concurrencyHelper *ConcurrencyHelper + gatewayService *service.GatewayService + geminiCompatService *service.GeminiMessagesCompatService + antigravityGatewayService *service.AntigravityGatewayService + userService *service.UserService + billingCacheService *service.BillingCacheService + concurrencyHelper *ConcurrencyHelper } // NewGatewayHandler creates a new GatewayHandler func NewGatewayHandler( gatewayService *service.GatewayService, geminiCompatService *service.GeminiMessagesCompatService, + antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, ) *GatewayHandler { return &GatewayHandler{ - gatewayService: gatewayService, - geminiCompatService: geminiCompatService, - userService: userService, - billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), + gatewayService: gatewayService, + geminiCompatService: geminiCompatService, + antigravityGatewayService: antigravityGatewayService, + userService: userService, + billingCacheService: billingCacheService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), } } @@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 转发请求 - result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body) + } else { + result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + } if accountReleaseFunc != nil { accountReleaseFunc() } @@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 转发请求 - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) + } else { + result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) + } if accountReleaseFunc != nil { accountReleaseFunc() } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 53625669..613d4c86 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) if err != nil { + // 没有 gemini 账户,检查是否有 antigravity 账户可用 + hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID) + if hasAntigravity { + // antigravity 账户使用静态模型列表 + c.JSON(http.StatusOK, gemini.FallbackModelsList()) + return + } googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } @@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) if err != nil { + // 没有 gemini 账户,检查是否有 antigravity 账户可用 + hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID) + if hasAntigravity { + // antigravity 账户使用静态模型信息 + c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) + return + } googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } @@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } - // 5) forward (writes response to client) - result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + // 5) forward (根据平台分流) + var result *service.ForwardResult + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) + } else { + result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + } if accountReleaseFunc != nil { accountReleaseFunc() } diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go new file mode 100644 index 00000000..82b30ee4 --- /dev/null +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -0,0 +1,143 @@ +//go:build unit + +package handler + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量 +// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期 +func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) { + tests := []struct { + name string + platform string + expectedService string + description string + }{ + { + name: "Gemini平台使用ForwardNative", + platform: service.PlatformGemini, + expectedService: "GeminiMessagesCompatService.ForwardNative", + description: "Gemini OAuth 账户直接调用 Google API", + }, + { + name: "Antigravity平台使用ForwardGemini", + platform: service.PlatformAntigravity, + expectedService: "AntigravityGatewayService.ForwardGemini", + description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go) + var routedService string + if tt.platform == service.PlatformAntigravity { + routedService = "AntigravityGatewayService.ForwardGemini" + } else { + routedService = "GeminiMessagesCompatService.ForwardNative" + } + + require.Equal(t, tt.expectedService, routedService, + "平台 %s 应该路由到 %s: %s", + tt.platform, tt.expectedService, tt.description) + }) + } +} + +// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑 +// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表 +func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) { + tests := []struct { + name string + hasGeminiAccount bool + hasAntigravity bool + expectedBehavior string + }{ + { + name: "有Gemini账户-调用ForwardAIStudioGET", + hasGeminiAccount: true, + hasAntigravity: false, + expectedBehavior: "forward_to_upstream", + }, + { + name: "无Gemini有Antigravity-返回静态列表", + hasGeminiAccount: false, + hasAntigravity: true, + expectedBehavior: "static_fallback", + }, + { + name: "无任何账户-返回503", + hasGeminiAccount: false, + hasAntigravity: false, + expectedBehavior: "service_unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go) + var behavior string + + if tt.hasGeminiAccount { + behavior = "forward_to_upstream" + } else if tt.hasAntigravity { + behavior = "static_fallback" + } else { + behavior = "service_unavailable" + } + + require.Equal(t, tt.expectedBehavior, behavior) + }) + } +} + +// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑 +func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { + tests := []struct { + name string + hasGeminiAccount bool + hasAntigravity bool + expectedBehavior string + }{ + { + name: "有Gemini账户-调用ForwardAIStudioGET", + hasGeminiAccount: true, + hasAntigravity: false, + expectedBehavior: "forward_to_upstream", + }, + { + name: "无Gemini有Antigravity-返回静态模型信息", + hasGeminiAccount: false, + hasAntigravity: true, + expectedBehavior: "static_model_info", + }, + { + name: "无任何账户-返回503", + hasGeminiAccount: false, + hasAntigravity: false, + expectedBehavior: "service_unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go) + var behavior string + + if tt.hasGeminiAccount { + behavior = "forward_to_upstream" + } else if tt.hasAntigravity { + behavior = "static_model_info" + } else { + behavior = "service_unavailable" + } + + require.Equal(t, tt.expectedBehavior, behavior) + }) + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index fe6053ee..326aa45d 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont return outAccounts, nil } +func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + var accounts []accountModel + now := time.Now() + err := r.db.WithContext(ctx). + Where("platform IN ?", platforms). + Where("status = ? AND schedulable = ?", service.StatusActive, true). + Where("(overload_until IS NULL OR overload_until <= ?)", now). + Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). + Preload("Proxy"). + Order("priority ASC"). + Find(&accounts).Error + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil +} + +func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + var accounts []accountModel + now := time.Now() + err := r.db.WithContext(ctx). + Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). + Where("account_groups.group_id = ?", groupID). + Where("accounts.platform IN ?", platforms). + Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true). + Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). + Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). + Preload("Proxy"). + Order("account_groups.priority ASC, accounts.priority ASC"). + Find(&accounts).Error + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil +} + func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { now := time.Now() return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). diff --git a/backend/internal/repository/gateway_routing_integration_test.go b/backend/internal/repository/gateway_routing_integration_test.go new file mode 100644 index 00000000..46a22f9c --- /dev/null +++ b/backend/internal/repository/gateway_routing_integration_test.go @@ -0,0 +1,250 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" + "gorm.io/datatypes" + "gorm.io/gorm" +) + +// GatewayRoutingSuite 测试网关路由相关的数据库查询 +// 验证账户选择和分流逻辑在真实数据库环境下的行为 +type GatewayRoutingSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + accountRepo *accountRepository +} + +func (s *GatewayRoutingSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.accountRepo = NewAccountRepository(s.db).(*accountRepository) +} + +func TestGatewayRoutingSuite(t *testing.T) { + suite.Run(t, new(GatewayRoutingSuite)) +} + +// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询 +func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() { + // 创建各平台账户 + geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "gemini-oauth", + Platform: service.PlatformGemini, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Priority: 1, + }) + + antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "antigravity-oauth", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Priority: 2, + Credentials: datatypes.JSONMap{ + "access_token": "test-token", + "refresh_token": "test-refresh", + "project_id": "test-project", + }, + }) + + // 创建不应被选中的 anthropic 账户 + mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "anthropic-oauth", + Platform: service.PlatformAnthropic, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Priority: 0, + }) + + // 查询 gemini + antigravity 平台 + accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{ + service.PlatformGemini, + service.PlatformAntigravity, + }) + + s.Require().NoError(err) + s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户") + + // 验证返回的账户平台 + platforms := make(map[string]bool) + for _, acc := range accounts { + platforms[acc.Platform] = true + } + s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户") + s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户") + s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户") + + // 验证账户 ID 匹配 + ids := make(map[int64]bool) + for _, acc := range accounts { + ids[acc.ID] = true + } + s.Require().True(ids[geminiAcc.ID]) + s.Require().True(ids[antigravityAcc.ID]) +} + +// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤 +func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() { + // 创建 gemini 分组 + group := mustCreateGroup(s.T(), s.db, &groupModel{ + Name: "gemini-group", + Platform: service.PlatformGemini, + Status: service.StatusActive, + }) + + // 创建账户 + boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "bound-antigravity", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + Schedulable: true, + }) + unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "unbound-antigravity", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + Schedulable: true, + }) + + // 只绑定一个账户到分组 + mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1) + + // 查询分组内的账户 + accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{ + service.PlatformGemini, + service.PlatformAntigravity, + }) + + s.Require().NoError(err) + s.Require().Len(accounts, 1, "应只返回绑定到分组的账户") + s.Require().Equal(boundAcc.ID, accounts[0].ID) + + // 确认未绑定的账户不在结果中 + for _, acc := range accounts { + s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户") + } +} + +// TestListSchedulableByPlatform_Antigravity 验证单平台查询 +func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() { + // 创建多种平台账户 + mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "gemini-1", + Platform: service.PlatformGemini, + Status: service.StatusActive, + Schedulable: true, + }) + + antigravity := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "antigravity-1", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + Schedulable: true, + }) + + // 只查询 antigravity 平台 + accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity) + + s.Require().NoError(err) + s.Require().Len(accounts, 1) + s.Require().Equal(antigravity.ID, accounts[0].ID) + s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform) +} + +// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤 +func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() { + // 创建可调度账户 + activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "active-antigravity", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + Schedulable: true, + }) + + // 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true) + inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "inactive-antigravity", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + }) + s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error) + + // 创建错误状态账户 + mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "error-antigravity", + Platform: service.PlatformAntigravity, + Status: service.StatusError, + Schedulable: true, + }) + + accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity) + + s.Require().NoError(err) + s.Require().Len(accounts, 1, "应只返回可调度的 active 账户") + s.Require().Equal(activeAcc.ID, accounts[0].ID) +} + +// TestPlatformRoutingDecision 验证平台路由决策 +// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑 +func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() { + // 创建两种平台的账户 + geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "gemini-route-test", + Platform: service.PlatformGemini, + Status: service.StatusActive, + Schedulable: true, + }) + + antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{ + Name: "antigravity-route-test", + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + Schedulable: true, + }) + + tests := []struct { + name string + accountID int64 + expectedService string + }{ + { + name: "Gemini账户路由到ForwardNative", + accountID: geminiAcc.ID, + expectedService: "GeminiMessagesCompatService.ForwardNative", + }, + { + name: "Antigravity账户路由到ForwardGemini", + accountID: antigravityAcc.ID, + expectedService: "AntigravityGatewayService.ForwardGemini", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // 从数据库获取账户 + account, err := s.accountRepo.GetByID(s.ctx, tt.accountID) + s.Require().NoError(err) + + // 模拟 Handler 层的路由决策 + var routedService string + if account.Platform == service.PlatformAntigravity { + routedService = "AntigravityGatewayService.ForwardGemini" + } else { + routedService = "GeminiMessagesCompatService.ForwardNative" + } + + s.Require().Equal(tt.expectedService, routedService) + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index be70987c..5eb81faf 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -38,6 +38,8 @@ type AccountRepository interface { ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) + ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) + ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go new file mode 100644 index 00000000..f41301c5 --- /dev/null +++ b/backend/internal/service/antigravity_gateway_service.go @@ -0,0 +1,845 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +const ( + antigravityStickySessionTTL = time.Hour + antigravityMaxRetries = 5 + antigravityRetryBaseDelay = 1 * time.Second + antigravityRetryMaxDelay = 16 * time.Second +) + +// Antigravity 直接支持的模型 +var antigravitySupportedModels = map[string]bool{ + "claude-opus-4-5-thinking": true, + "claude-sonnet-4-5": true, + "claude-sonnet-4-5-thinking": true, + "gemini-2.5-flash": true, + "gemini-2.5-flash-lite": true, + "gemini-2.5-flash-thinking": true, + "gemini-3-flash": true, + "gemini-3-pro-low": true, + "gemini-3-pro-high": true, + "gemini-3-pro-preview": true, + "gemini-3-pro-image": true, +} + +// Antigravity 系统默认模型映射表(不支持 → 支持) +var antigravityModelMapping = map[string]string{ + "claude-3-5-sonnet-20241022": "claude-sonnet-4-5", + "claude-3-5-sonnet-20240620": "claude-sonnet-4-5", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", + "claude-opus-4": "claude-opus-4-5-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-5-thinking", + "claude-haiku-4": "claude-sonnet-4-5", + "claude-3-haiku-20240307": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", +} + +// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 +type AntigravityGatewayService struct { + accountRepo AccountRepository + cache GatewayCache + tokenProvider *AntigravityTokenProvider + rateLimitService *RateLimitService + httpUpstream HTTPUpstream +} + +func NewAntigravityGatewayService( + accountRepo AccountRepository, + cache GatewayCache, + tokenProvider *AntigravityTokenProvider, + rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, +) *AntigravityGatewayService { + return &AntigravityGatewayService{ + accountRepo: accountRepo, + cache: cache, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + } +} + +// GetTokenProvider 返回 token provider +func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider { + return s.tokenProvider +} + +// getMappedModel 获取映射后的模型名 +func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { + // 1. 优先使用账户级映射(复用现有方法) + if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { + return mapped + } + + // 2. 系统默认映射 + if mapped, ok := antigravityModelMapping[requestedModel]; ok { + return mapped + } + + // 3. Gemini 模型透传 + if strings.HasPrefix(requestedModel, "gemini-") { + return requestedModel + } + + // 4. Claude 前缀透传直接支持的模型 + if antigravitySupportedModels[requestedModel] { + return requestedModel + } + + // 5. 默认值 + return "claude-sonnet-4-5" +} + +// IsModelSupported 检查模型是否被支持 +func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool { + // 直接支持的模型 + if antigravitySupportedModels[requestedModel] { + return true + } + // 可映射的模型 + if _, ok := antigravityModelMapping[requestedModel]; ok { + return true + } + // Gemini 前缀透传 + if strings.HasPrefix(requestedModel, "gemini-") { + return true + } + // Claude 模型支持(通过默认映射) + if strings.HasPrefix(requestedModel, "claude-") { + return true + } + return false +} + +// wrapV1InternalRequest 包装请求为 v1internal 格式 +func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) { + var request any + if err := json.Unmarshal(originalBody, &request); err != nil { + return nil, fmt.Errorf("解析请求体失败: %w", err) + } + + wrapped := map[string]any{ + "project": projectID, + "requestId": "agent-" + uuid.New().String(), + "userAgent": "sub2api", + "requestType": "agent", + "model": model, + "request": request, + } + + return json.Marshal(wrapped) +} + +// unwrapV1InternalResponse 解包 v1internal 响应 +func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) { + var outer map[string]any + if err := json.Unmarshal(body, &outer); err != nil { + return nil, err + } + + if resp, ok := outer["response"]; ok { + return json.Marshal(resp) + } + + return body, nil +} + +// unwrapSSELine 解包 SSE 行中的 v1internal 响应 +func (s *AntigravityGatewayService) unwrapSSELine(line string) string { + if !strings.HasPrefix(line, "data: ") { + return line + } + + data := strings.TrimPrefix(line, "data: ") + if data == "" || data == "[DONE]" { + return line + } + + var outer map[string]any + if err := json.Unmarshal([]byte(data), &outer); err != nil { + return line + } + + if resp, ok := outer["response"]; ok { + unwrapped, err := json.Marshal(resp) + if err != nil { + return line + } + return "data: " + string(unwrapped) + } + + return line +} + +// Forward 转发 Claude 协议请求 +func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + // 解析请求获取 model 和 stream + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(body, &req); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if strings.TrimSpace(req.Model) == "" { + return nil, fmt.Errorf("missing model") + } + + originalModel := req.Model + mappedModel := s.getMappedModel(account, req.Model) + if mappedModel != req.Model { + log.Printf("Antigravity model mapping: %s -> %s (account: %s)", req.Model, mappedModel, account.Name) + } + + // 获取 access_token + if s.tokenProvider == nil { + return nil, errors.New("antigravity token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("获取 access_token 失败: %w", err) + } + + // 获取 project_id + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + return nil, errors.New("project_id not found in credentials") + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 包装请求 + wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body) + if err != nil { + return nil, err + } + + // 构建上游 URL + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action) + if req.Stream { + fullURL += "?alt=sse" + } + + // 重试循环 + var resp *http.Response + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody)) + if err != nil { + return nil, err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + if err != nil { + if attempt < antigravityMaxRetries { + log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) + sleepAntigravityBackoff(attempt) + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < antigravityMaxRetries { + log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries) + sleepAntigravityBackoff(attempt) + continue + } + // 最后一次尝试也失败 + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + + return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody) + } + + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if req.Stream { + streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(c, resp, originalModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ForwardGemini 转发 Gemini 协议请求 +func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + if strings.TrimSpace(originalModel) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") + } + if strings.TrimSpace(action) == "" { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") + } + if len(body) == 0 { + return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") + } + + switch action { + case "generateContent", "streamGenerateContent", "countTokens": + // ok + default: + return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) + } + + mappedModel := s.getMappedModel(account, originalModel) + + // 获取 access_token + if s.tokenProvider == nil { + return nil, errors.New("antigravity token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("获取 access_token 失败: %w", err) + } + + // 获取 project_id + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + return nil, errors.New("project_id not found in credentials") + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 包装请求 + wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body) + if err != nil { + return nil, err + } + + // 构建上游 URL + upstreamAction := action + if action == "generateContent" && stream { + upstreamAction = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction) + if stream || upstreamAction == "streamGenerateContent" { + fullURL += "?alt=sse" + } + + // 重试循环 + var resp *http.Response + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody)) + if err != nil { + return nil, err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + if err != nil { + if attempt < antigravityMaxRetries { + log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) + sleepAntigravityBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < antigravityMaxRetries { + log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries) + sleepAntigravityBackoff(attempt) + continue + } + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + requestID := resp.Header.Get("x-request-id") + if requestID != "" { + c.Header("x-request-id", requestID) + } + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + if action == "countTokens" { + estimated := estimateGeminiCountTokens(body) + c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) + return &ForwardResult{ + RequestID: requestID, + Usage: ClaudeUsage{}, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + + // 解包并返回错误 + unwrapped, _ := s.unwrapV1InternalResponse(respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, unwrapped) + return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) + } + + var usage *ClaudeUsage + var firstTokenMs *int + + if stream || upstreamAction == "streamGenerateContent" { + streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + usageResp, err := s.handleGeminiNonStreamingResponse(c, resp) + if err != nil { + return nil, err + } + usage = usageResp + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool { + switch statusCode { + case 429, 500, 502, 503, 504, 529: + return true + default: + return false + } +} + +func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func sleepAntigravityBackoff(attempt int) { + sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑 +} + +func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { + if s.rateLimitService == nil { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) +} + +type antigravityStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func (s *AntigravityGatewayService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + reader := bufio.NewReader(resp.Body) + + for { + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("stream read error: %w", err) + } + + if len(line) > 0 { + // 解包 v1internal 响应 + unwrapped := s.unwrapSSELine(strings.TrimRight(line, "\r\n")) + + // 解析 usage + if strings.HasPrefix(unwrapped, "data: ") { + data := strings.TrimPrefix(unwrapped, "data: ") + if data != "" && data != "[DONE]" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseClaudeSSEUsage(data, usage) + } + } + + // 写入响应 + if _, writeErr := fmt.Fprintf(c.Writer, "%s\n", unwrapped); writeErr != nil { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, writeErr + } + flusher.Flush() + } + + if errors.Is(err, io.EOF) { + break + } + } + + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *AntigravityGatewayService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + } + + // 解包 v1internal 响应 + unwrapped, err := s.unwrapV1InternalResponse(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + // 解析 usage + var respObj struct { + Usage ClaudeUsage `json:"usage"` + } + _ = json.Unmarshal(unwrapped, &respObj) + + c.Data(http.StatusOK, "application/json", unwrapped) + return &respObj.Usage, nil +} + +func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { + c.Status(resp.StatusCode) + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream; charset=utf-8" + } + c.Header("Content-Type", contentType) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + reader := bufio.NewReader(resp.Body) + usage := &ClaudeUsage{} + var firstTokenMs *int + + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if payload == "" || payload == "[DONE]" { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } else { + // 解包 v1internal 响应 + inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) + if parseErr == nil && inner != nil { + payload = string(inner) + } + + // 解析 usage + var parsed map[string]any + if json.Unmarshal(inner, &parsed) == nil { + if u := extractGeminiUsage(parsed); u != nil { + usage = u + } + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload) + flusher.Flush() + } + } else { + _, _ = io.WriteString(c.Writer, line) + flusher.Flush() + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + } + + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // 解包 v1internal 响应 + unwrapped, _ := s.unwrapV1InternalResponse(respBody) + + var parsed map[string]any + if json.Unmarshal(unwrapped, &parsed) == nil { + if u := extractGeminiUsage(parsed); u != nil { + c.Data(resp.StatusCode, "application/json", unwrapped) + return u, nil + } + } + + c.Data(resp.StatusCode, "application/json", unwrapped) + return &ClaudeUsage{}, nil +} + +func (s *AntigravityGatewayService) parseClaudeSSEUsage(data string, usage *ClaudeUsage) { + // 解析 message_start 获取 input tokens + var msgStart struct { + Type string `json:"type"` + Message struct { + Usage ClaudeUsage `json:"usage"` + } `json:"message"` + } + if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { + usage.InputTokens = msgStart.Message.Usage.InputTokens + usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens + usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + } + + // 解析 message_delta 获取 output tokens + var msgDelta struct { + Type string `json:"type"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + } `json:"usage"` + } + if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { + usage.OutputTokens = msgDelta.Usage.OutputTokens + if usage.InputTokens == 0 { + usage.InputTokens = msgDelta.Usage.InputTokens + } + if usage.CacheCreationInputTokens == 0 { + usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens + } + if usage.CacheReadInputTokens == 0 { + usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens + } + } +} + +func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": message}, + }) + return fmt.Errorf("%s", message) +} + +func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error { + var statusCode int + var errType, errMsg string + + switch upstreamStatus { + case 400: + statusCode = http.StatusBadRequest + errType = "invalid_request_error" + errMsg = "Invalid request" + case 401: + statusCode = http.StatusBadGateway + errType = "authentication_error" + errMsg = "Upstream authentication failed" + case 403: + statusCode = http.StatusBadGateway + errType = "permission_error" + errMsg = "Upstream access forbidden" + case 429: + statusCode = http.StatusTooManyRequests + errType = "rate_limit_error" + errMsg = "Upstream rate limit exceeded" + case 529: + statusCode = http.StatusServiceUnavailable + errType = "overloaded_error" + errMsg = "Upstream service overloaded" + default: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream request failed" + } + + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + return fmt.Errorf("upstream error: %d", upstreamStatus) +} + +func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { + statusStr := "UNKNOWN" + switch status { + case 400: + statusStr = "INVALID_ARGUMENT" + case 404: + statusStr = "NOT_FOUND" + case 429: + statusStr = "RESOURCE_EXHAUSTED" + case 500: + statusStr = "INTERNAL" + case 502, 503: + statusStr = "UNAVAILABLE" + } + + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": statusStr, + }, + }) + return fmt.Errorf("%s", message) +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go new file mode 100644 index 00000000..a6dd701b --- /dev/null +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -0,0 +1,257 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsAntigravityModelSupported(t *testing.T) { + tests := []struct { + name string + model string + expected bool + }{ + // 直接支持的模型 + {"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, + {"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true}, + {"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true}, + {"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true}, + {"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true}, + {"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true}, + + // 可映射的模型 + {"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true}, + {"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true}, + {"可映射 - claude-opus-4", "claude-opus-4", true}, + {"可映射 - claude-haiku-4", "claude-haiku-4", true}, + {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, + + // Gemini 前缀透传 + {"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true}, + {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, + {"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, + + // Claude 前缀兜底 + {"Claude前缀 - claude-unknown-model", "claude-unknown-model", true}, + {"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true}, + {"Claude前缀 - claude-future-version", "claude-future-version", true}, + + // 不支持的模型 + {"不支持 - gpt-4", "gpt-4", false}, + {"不支持 - gpt-4o", "gpt-4o", false}, + {"不支持 - llama-3", "llama-3", false}, + {"不支持 - mistral-7b", "mistral-7b", false}, + {"不支持 - 空字符串", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsAntigravityModelSupported(tt.model) + require.Equal(t, tt.expected, got, "model: %s", tt.model) + }) + } +} + +func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { + svc := &AntigravityGatewayService{} + + tests := []struct { + name string + requestedModel string + accountMapping map[string]string + expected string + }{ + // 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any) + { + name: "账户映射优先", + requestedModel: "claude-3-5-sonnet-20241022", + accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"}, + expected: "custom-model", + }, + { + name: "账户映射覆盖系统映射", + requestedModel: "claude-opus-4", + accountMapping: map[string]string{"claude-opus-4": "my-opus"}, + expected: "my-opus", + }, + + // 2. 系统默认映射 + { + name: "系统映射 - claude-3-5-sonnet-20241022", + requestedModel: "claude-3-5-sonnet-20241022", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "系统映射 - claude-3-5-sonnet-20240620", + requestedModel: "claude-3-5-sonnet-20240620", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "系统映射 - claude-opus-4", + requestedModel: "claude-opus-4", + accountMapping: nil, + expected: "claude-opus-4-5-thinking", + }, + { + name: "系统映射 - claude-opus-4-5-20251101", + requestedModel: "claude-opus-4-5-20251101", + accountMapping: nil, + expected: "claude-opus-4-5-thinking", + }, + { + name: "系统映射 - claude-haiku-4", + requestedModel: "claude-haiku-4", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "系统映射 - claude-3-haiku-20240307", + requestedModel: "claude-3-haiku-20240307", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "系统映射 - claude-sonnet-4-5-20250929", + requestedModel: "claude-sonnet-4-5-20250929", + accountMapping: nil, + expected: "claude-sonnet-4-5-thinking", + }, + + // 3. Gemini 透传 + { + name: "Gemini透传 - gemini-2.5-flash", + requestedModel: "gemini-2.5-flash", + accountMapping: nil, + expected: "gemini-2.5-flash", + }, + { + name: "Gemini透传 - gemini-1.5-pro", + requestedModel: "gemini-1.5-pro", + accountMapping: nil, + expected: "gemini-1.5-pro", + }, + { + name: "Gemini透传 - gemini-future-model", + requestedModel: "gemini-future-model", + accountMapping: nil, + expected: "gemini-future-model", + }, + + // 4. 直接支持的模型 + { + name: "直接支持 - claude-sonnet-4-5", + requestedModel: "claude-sonnet-4-5", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "直接支持 - claude-opus-4-5-thinking", + requestedModel: "claude-opus-4-5-thinking", + accountMapping: nil, + expected: "claude-opus-4-5-thinking", + }, + { + name: "直接支持 - claude-sonnet-4-5-thinking", + requestedModel: "claude-sonnet-4-5-thinking", + accountMapping: nil, + expected: "claude-sonnet-4-5-thinking", + }, + + // 5. 默认值 fallback(未知 claude 模型) + { + name: "默认值 - claude-unknown", + requestedModel: "claude-unknown", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + { + name: "默认值 - claude-3-opus-20240229", + requestedModel: "claude-3-opus-20240229", + accountMapping: nil, + expected: "claude-sonnet-4-5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + } + if tt.accountMapping != nil { + // GetModelMapping 期望 model_mapping 是 map[string]any 格式 + mappingAny := make(map[string]any) + for k, v := range tt.accountMapping { + mappingAny[k] = v + } + account.Credentials = map[string]any{ + "model_mapping": mappingAny, + } + } + + got := svc.getMappedModel(account, tt.requestedModel) + require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel) + }) + } +} + +func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) { + svc := &AntigravityGatewayService{} + + tests := []struct { + name string + requestedModel string + expected string + }{ + // 空字符串回退到默认值 + {"空字符串", "", "claude-sonnet-4-5"}, + + // 非 claude/gemini 前缀回退到默认值 + {"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"}, + {"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{Platform: PlatformAntigravity} + got := svc.getMappedModel(account, tt.requestedModel) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestAntigravityGatewayService_IsModelSupported(t *testing.T) { + svc := &AntigravityGatewayService{} + + tests := []struct { + name string + model string + expected bool + }{ + // 直接支持 + {"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true}, + {"直接支持 - gemini-3-flash", "gemini-3-flash", true}, + + // 可映射 + {"可映射 - claude-opus-4", "claude-opus-4", true}, + + // 前缀透传 + {"Gemini前缀", "gemini-unknown", true}, + {"Claude前缀", "claude-unknown", true}, + + // 不支持 + {"不支持 - gpt-4", "gpt-4", false}, + {"不支持 - 空字符串", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.IsModelSupported(tt.model) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go new file mode 100644 index 00000000..724b940d --- /dev/null +++ b/backend/internal/service/antigravity_token_provider.go @@ -0,0 +1,145 @@ +package service + +import ( + "context" + "errors" + "log" + "strconv" + "strings" + "time" +) + +const ( + antigravityTokenRefreshSkew = 3 * time.Minute + antigravityTokenCacheSkew = 5 * time.Minute +) + +// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +type AntigravityTokenCache = GeminiTokenCache + +// AntigravityTokenProvider 管理 Antigravity 账户的 access_token +type AntigravityTokenProvider struct { + accountRepo AccountRepository + tokenCache AntigravityTokenCache + antigravityOAuthService *AntigravityOAuthService +} + +func NewAntigravityTokenProvider( + accountRepo AccountRepository, + tokenCache AntigravityTokenCache, + antigravityOAuthService *AntigravityOAuthService, +) *AntigravityTokenProvider { + return &AntigravityTokenProvider{ + accountRepo: accountRepo, + tokenCache: tokenCache, + antigravityOAuthService: antigravityOAuthService, + } +} + +// GetAccessToken 获取有效的 access_token +func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { + return "", errors.New("not an antigravity oauth account") + } + + cacheKey := antigravityTokenCacheKey(account) + + // 1. 先尝试缓存 + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + // 2. 如果即将过期则刷新 + expiresAt := parseAntigravityExpiresAt(account) + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + + // 从数据库获取最新账户信息 + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = parseAntigravityExpiresAt(account) + if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew { + if p.antigravityOAuthService == nil { + return "", errors.New("antigravity oauth service not configured") + } + tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return "", err + } + newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) + } + expiresAt = parseAntigravityExpiresAt(account) + } + } + } + + accessToken := account.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3. 存入缓存 + if p.tokenCache != nil { + ttl := 30 * time.Minute + if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > antigravityTokenCacheSkew: + ttl = until - antigravityTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + + return accessToken, nil +} + +func antigravityTokenCacheKey(account *Account) string { + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID != "" { + return "ag:" + projectID + } + return "ag:account:" + strconv.FormatInt(account.ID, 10) +} + +func parseAntigravityExpiresAt(account *Account) *time.Time { + raw := strings.TrimSpace(account.GetCredential("expires_at")) + if raw == "" { + return nil + } + if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 { + t := time.Unix(unixSec, 0) + return &t + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return &t + } + return nil +} diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go new file mode 100644 index 00000000..1d2b8f15 --- /dev/null +++ b/backend/internal/service/antigravity_token_refresher.go @@ -0,0 +1,57 @@ +package service + +import ( + "context" + "strconv" + "time" +) + +// AntigravityTokenRefresher 实现 TokenRefresher 接口 +type AntigravityTokenRefresher struct { + antigravityOAuthService *AntigravityOAuthService +} + +func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher { + return &AntigravityTokenRefresher{ + antigravityOAuthService: antigravityOAuthService, + } +} + +// CanRefresh 检查是否可以刷新此账户 +func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { + return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth +} + +// NeedsRefresh 检查账户是否需要刷新 +func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { + if !r.CanRefresh(account) { + return false + } + expiresAtStr := account.GetCredential("expires_at") + if expiresAtStr == "" { + return false + } + expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) + if err != nil { + return false + } + expiryTime := time.Unix(expiresAt, 0) + return time.Until(expiryTime) < refreshWindow +} + +// Refresh 执行 token 刷新 +func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return nil, err + } + + newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + + return newCredentials, nil +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go new file mode 100644 index 00000000..df424f25 --- /dev/null +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -0,0 +1,565 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// mockAccountRepoForMultiplatform 多平台测试用的 mock +type mockAccountRepoForMultiplatform struct { + accounts []Account + accountsByID map[int64]*Account + listPlatformsFunc func(ctx context.Context, platforms []string) ([]Account, error) +} + +func (m *mockAccountRepoForMultiplatform) GetByID(ctx context.Context, id int64) (*Account, error) { + if acc, ok := m.accountsByID[id]; ok { + return acc, nil + } + return nil, errors.New("account not found") +} + +func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + if m.listPlatformsFunc != nil { + return m.listPlatformsFunc(ctx, platforms) + } + // 过滤符合平台的账户 + var result []Account + platformSet := make(map[string]bool) + for _, p := range platforms { + platformSet[p] = true + } + for _, acc := range m.accounts { + if platformSet[acc.Platform] && acc.IsSchedulable() { + result = append(result, acc) + } + } + return result, nil +} + +func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} + +// Stub methods to implement AccountRepository interface +func (m *mockAccountRepoForMultiplatform) Create(ctx context.Context, account *Account) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) Update(ctx context.Context, account *Account) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) Delete(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForMultiplatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListActive(ctx context.Context) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) UpdateLastUsed(ctx context.Context, id int64) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) SetError(ctx context.Context, id int64, errorMsg string) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) ListSchedulable(ctx context.Context) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForMultiplatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) ClearRateLimit(ctx context.Context, id int64) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return nil +} +func (m *mockAccountRepoForMultiplatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) { + return 0, nil +} + +// Verify interface implementation +var _ AccountRepository = (*mockAccountRepoForMultiplatform)(nil) + +// mockGatewayCacheForMultiplatform 多平台测试用的 cache mock +type mockGatewayCacheForMultiplatform struct { + sessionBindings map[string]int64 +} + +func (m *mockGatewayCacheForMultiplatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { + if id, ok := m.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (m *mockGatewayCacheForMultiplatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { + if m.sessionBindings == nil { + m.sessionBindings = make(map[string]int64) + } + m.sessionBindings[sessionHash] = accountID + return nil +} + +func (m *mockGatewayCacheForMultiplatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { + return nil +} + +func ptr[T any](v T) *T { + return &v +} + +func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAnthropic(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户") +} + +func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) + require.Equal(t, PlatformAntigravity, acc.Platform) +} + +func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) { + ctx := context.Background() + now := time.Now() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)") +} + +func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_DiffPriority(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选择优先级更高的账户(Antigravity, priority=1)") +} + +func TestGatewayService_SelectAccountForModelWithExclusions_ModelNotSupported(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + // Anthropic 账户配置了模型映射,只支持 other-model + // 注意:model_mapping 需要是 map[string]any 格式 + { + ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"other-model": "x"}}, + }, + // Antigravity 账户支持所有 claude 模型 + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "Anthropic 不支持该模型,应选择 Antigravity") +} + +func TestGatewayService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{}, + accountsByID: map[int64]*Account{}, + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "no available accounts") +} + +func TestGatewayService_SelectAccountForModelWithExclusions_AllExcluded(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + excludedIDs := map[int64]struct{}{1: {}, 2: {}} + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity}) + require.Error(t, err) + require.Nil(t, acc) +} + +func TestGatewayService_SelectAccountForModelWithExclusions_Schedulability(t *testing.T) { + ctx := context.Background() + now := time.Now() + + tests := []struct { + name string + accounts []Account + expectedID int64 + }{ + { + name: "过载账户被跳过", + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + expectedID: 2, + }, + { + name: "限流账户被跳过", + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + expectedID: 2, + }, + { + name: "非active账户被跳过", + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + expectedID: 2, + }, + { + name: "schedulable=false被跳过", + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + expectedID: 2, + }, + { + name: "过期的过载账户可调度", + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + expectedID: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &mockAccountRepoForMultiplatform{ + accounts: tt.accounts, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, tt.expectedID, acc.ID) + }) + } +} + +func TestGatewayService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) { + ctx := context.Background() + + t.Run("粘性会话命中", func(t *testing.T) { + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户") + }) + + t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) { + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + excludedIDs := map[int64]struct{}{1: {}} + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户") + }) + + t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) { + repo := &mockAccountRepoForMultiplatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForMultiplatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity}) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户") + }) +} + +func TestGatewayService_isModelSupportedByAccount(t *testing.T) { + svc := &GatewayService{} + + tests := []struct { + name string + account *Account + model string + expected bool + }{ + { + name: "Antigravity平台-支持claude模型", + account: &Account{Platform: PlatformAntigravity}, + model: "claude-3-5-sonnet-20241022", + expected: true, + }, + { + name: "Antigravity平台-支持gemini模型", + account: &Account{Platform: PlatformAntigravity}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Antigravity平台-不支持gpt模型", + account: &Account{Platform: PlatformAntigravity}, + model: "gpt-4", + expected: false, + }, + { + name: "Anthropic平台-无映射配置-支持所有模型", + account: &Account{Platform: PlatformAnthropic}, + model: "claude-3-5-sonnet-20241022", + expected: true, + }, + { + name: "Anthropic平台-有映射配置-只支持配置的模型", + account: &Account{ + Platform: PlatformAnthropic, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}}, + }, + model: "claude-3-5-sonnet-20241022", + expected: false, + }, + { + name: "Anthropic平台-有映射配置-支持配置的模型", + account: &Account{ + Platform: PlatformAnthropic, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}}, + }, + model: "claude-3-5-sonnet-20241022", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isModelSupportedByAccount(tt.account, tt.model) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d4e1a07b..1c7fde96 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 使用多平台账户选择,包含 anthropic 和 antigravity 平台 + platforms := []string{PlatformAnthropic, PlatformAntigravity} + return s.selectAccountForModelWithPlatforms(ctx, groupID, sessionHash, requestedModel, excludedIDs, platforms) +} + +// selectAccountForModelWithPlatforms 选择多平台账户 +func (s *GatewayService) selectAccountForModelWithPlatforms(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platforms []string) (*Account, error) { // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 - // 同时检查模型支持 - if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + // 同时检查模型支持(根据平台类型分别处理) + if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { // 续期粘性会话 if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) @@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } } - // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) + // 2. 获取可调度账号列表(排除限流和过载的账号,支持多平台) var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic) + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context if _, excluded := excludedIDs[acc.ID]; excluded { continue } - // 检查模型支持 - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + // 检查模型支持(根据平台类型分别处理) + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } if selected == nil { @@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return selected, nil } +// isModelSupportedByAccount 根据账户平台检查模型支持 +func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + // Antigravity 平台使用专门的模型支持检查 + return IsAntigravityModelSupported(requestedModel) + } + // 其他平台使用账户的模型支持检查 + return account.IsModelSupported(requestedModel) +} + +// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 +func IsAntigravityModelSupported(requestedModel string) bool { + // 直接支持的模型 + if antigravitySupportedModels[requestedModel] { + return true + } + // 可映射的模型 + if _, ok := antigravityModelMapping[requestedModel]; ok { + return true + } + // Gemini 前缀透传 + if strings.HasPrefix(requestedModel, "gemini-") { + return true + } + // Claude 模型支持(通过默认映射到 claude-sonnet-4-5) + if strings.HasPrefix(requestedModel, "claude-") { + return true + } + return false +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index c4a474c1..1e7f23af 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -33,11 +33,12 @@ const ( ) type GeminiMessagesCompatService struct { - accountRepo AccountRepository - cache GatewayCache - tokenProvider *GeminiTokenProvider - rateLimitService *RateLimitService - httpUpstream HTTPUpstream + accountRepo AccountRepository + cache GatewayCache + tokenProvider *GeminiTokenProvider + rateLimitService *RateLimitService + httpUpstream HTTPUpstream + antigravityGatewayService *AntigravityGatewayService } func NewGeminiMessagesCompatService( @@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService( tokenProvider *GeminiTokenProvider, rateLimitService *RateLimitService, httpUpstream HTTPUpstream, + antigravityGatewayService *AntigravityGatewayService, ) *GeminiMessagesCompatService { return &GeminiMessagesCompatService{ - accountRepo: accountRepo, - cache: cache, - tokenProvider: tokenProvider, - rateLimitService: rateLimitService, - httpUpstream: httpUpstream, + accountRepo: accountRepo, + cache: cache, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + antigravityGatewayService: antigravityGatewayService, } } @@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { cacheKey := "gemini:" + sessionHash + platforms := []string{PlatformGemini, PlatformAntigravity} + if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + // 支持 gemini 和 antigravity 平台的粘性会话 + if err == nil && account.IsSchedulable() && (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) return account, nil } @@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } + // 同时查询 gemini 和 antigravity 平台的可调度账户 var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co if _, excluded := excludedIDs[acc.ID]; excluded { continue } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + // 根据平台类型分别检查模型支持 + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } if selected == nil { @@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co if selected == nil { if requestedModel != "" { - return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available Gemini/Antigravity accounts supporting model: %s", requestedModel) } - return nil, errors.New("no available Gemini accounts") + return nil, errors.New("no available Gemini/Antigravity accounts") } if sessionHash != "" { @@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co return selected, nil } +// isModelSupportedByAccount 根据账户平台检查模型支持 +func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + return IsAntigravityModelSupported(requestedModel) + } + return account.IsModelSupported(requestedModel) +} + +// GetAntigravityGatewayService 返回 AntigravityGatewayService +func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService { + return s.antigravityGatewayService +} + +// HasAntigravityAccounts 检查是否有可用的 antigravity 账户 +func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity) + } + if err != nil { + return false, err + } + return len(accounts) > 0, nil +} + // SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against // generativelanguage.googleapis.com (e.g. GET /v1beta/models). // diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go new file mode 100644 index 00000000..9fd8ae49 --- /dev/null +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -0,0 +1,568 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// mockAccountRepoForGemini Gemini 测试用的 mock +type mockAccountRepoForGemini struct { + accounts []Account + accountsByID map[int64]*Account +} + +func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) { + if acc, ok := m.accountsByID[id]; ok { + return acc, nil + } + return nil, errors.New("account not found") +} + +func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.accounts { + if platformSet[acc.Platform] && acc.IsSchedulable() { + result = append(result, acc) + } + } + return result, nil +} + +func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} + +// Stub methods to implement AccountRepository interface +func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil } +func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { + return nil, nil +} +func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil } +func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) { return nil, nil } +func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} +func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { + return nil +} +func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return nil +} +func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return nil +} +func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + return nil, nil +} +func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.accounts { + if acc.Platform == platform && acc.IsSchedulable() { + result = append(result, acc) + } + } + return result, nil +} +func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + // 测试时不区分 groupID,直接按 platform 过滤 + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return nil +} +func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return nil +} +func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil } +func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return nil +} +func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return nil +} +func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) { + return 0, nil +} + +// Verify interface implementation +var _ AccountRepository = (*mockAccountRepoForGemini)(nil) + +// mockGatewayCacheForGemini Gemini 测试用的 cache mock +type mockGatewayCacheForGemini struct { + sessionBindings map[string]int64 +} + +func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { + if id, ok := m.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { + if m.sessionBindings == nil { + m.sessionBindings = make(map[string]int64) + } + m.sessionBindings[sessionHash] = accountID + return nil +} + +func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { + return nil +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyGemini(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) + require.Equal(t, PlatformAntigravity, acc.Platform) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludesAnthropic(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 3, Platform: PlatformAntigravity, Priority: 3, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + // Anthropic 不在 [gemini, antigravity] 平台列表中,应被过滤 + require.Equal(t, int64(2), acc.ID, "Anthropic 平台应被排除,选择 Gemini") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) { + ctx := context.Background() + now := time.Now() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, + {ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户") + require.Equal(t, AccountTypeOAuth, acc.Type) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred_MixedPlatforms(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, + {ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "跨平台时,同样优先选择 OAuth 账户") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{}, + accountsByID: map[int64]*Account{}, + } + + cache := &mockGatewayCacheForGemini{} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "no available Gemini/Antigravity accounts") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) { + ctx := context.Background() + + t.Run("粘性会话命中-使用gemini前缀缓存键", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + // 注意:缓存键使用 "gemini:" 前缀 + cache := &mockGatewayCacheForGemini{ + sessionBindings: map[string]int64{"gemini:session-123": 1}, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户") + }) + + t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + // 缓存键没有 "gemini:" 前缀,不应命中 + cache := &mockGatewayCacheForGemini{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + // 粘性会话未命中,按优先级选择 + require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择 Antigravity") + }) + + t.Run("粘性会话Anthropic账户-降级选择", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{ + sessionBindings: map[string]int64{"gemini:session-123": 1}, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + // 粘性会话绑定的是 Anthropic 账户,不在 Gemini 平台列表中,应降级选择 + require.Equal(t, int64(2), acc.ID, "粘性会话账户是 Anthropic,应降级选择 Gemini 平台账户") + }) +} + +func TestGeminiMessagesCompatService_HasAntigravityAccounts(t *testing.T) { + ctx := context.Background() + + t.Run("有antigravity账户时返回true", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true}, + }, + } + + svc := &GeminiMessagesCompatService{accountRepo: repo} + + has, err := svc.HasAntigravityAccounts(ctx, nil) + require.NoError(t, err) + require.True(t, has) + }) + + t.Run("无antigravity账户时返回false", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true}, + }, + } + + svc := &GeminiMessagesCompatService{accountRepo: repo} + + has, err := svc.HasAntigravityAccounts(ctx, nil) + require.NoError(t, err) + require.False(t, has) + }) + + t.Run("antigravity账户不可调度时返回false", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: false}, + }, + } + + svc := &GeminiMessagesCompatService{accountRepo: repo} + + has, err := svc.HasAntigravityAccounts(ctx, nil) + require.NoError(t, err) + require.False(t, has) + }) + + t.Run("带groupID查询", func(t *testing.T) { + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true}, + }, + } + + svc := &GeminiMessagesCompatService{accountRepo: repo} + + groupID := int64(1) + has, err := svc.HasAntigravityAccounts(ctx, &groupID) + require.NoError(t, err) + require.True(t, has) + }) +} + +// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑 +// 该测试文档化了 Handler 层应该如何根据 account.Platform 进行分流 +func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) { + tests := []struct { + name string + platform string + expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini + }{ + { + name: "Gemini平台走ForwardNative", + platform: PlatformGemini, + expectedService: "gemini", + }, + { + name: "Antigravity平台走ForwardGemini", + platform: PlatformAntigravity, + expectedService: "antigravity", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{Platform: tt.platform} + + // 模拟 Handler 层的路由逻辑 + var serviceName string + if account.Platform == PlatformAntigravity { + serviceName = "antigravity" + } else { + serviceName = "gemini" + } + + require.Equal(t, tt.expectedService, serviceName, + "平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService) + }) + } +} + +func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { + svc := &GeminiMessagesCompatService{} + + tests := []struct { + name string + account *Account + model string + expected bool + }{ + { + name: "Antigravity平台-支持gemini模型", + account: &Account{Platform: PlatformAntigravity}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Antigravity平台-支持claude模型", + account: &Account{Platform: PlatformAntigravity}, + model: "claude-3-5-sonnet-20241022", + expected: true, + }, + { + name: "Antigravity平台-不支持gpt模型", + account: &Account{Platform: PlatformAntigravity}, + model: "gpt-4", + expected: false, + }, + { + name: "Gemini平台-无映射配置-支持所有模型", + account: &Account{Platform: PlatformGemini}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Gemini平台-有映射配置-只支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}}, + }, + model: "gemini-2.5-flash", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isModelSupportedByAccount(tt.account, tt.model) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 23126bfb..76ca61fd 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -27,6 +27,7 @@ func NewTokenRefreshService( oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, + antigravityOAuthService *AntigravityOAuthService, cfg *config.Config, ) *TokenRefreshService { s := &TokenRefreshService{ @@ -40,6 +41,7 @@ func NewTokenRefreshService( NewClaudeTokenRefresher(oauthService), NewOpenAITokenRefresher(openaiOAuthService), NewGeminiTokenRefresher(geminiOAuthService), + NewAntigravityTokenRefresher(antigravityOAuthService), } return s diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index e1012acb..5927dd5c 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -39,9 +39,10 @@ func ProvideTokenRefreshService( oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, + antigravityOAuthService *AntigravityOAuthService, cfg *config.Config, ) *TokenRefreshService { - svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg) + svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg) svc.Start() return svc } @@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet( NewAntigravityOAuthService, NewGeminiTokenProvider, NewGeminiMessagesCompatService, + NewAntigravityTokenProvider, + NewAntigravityGatewayService, NewRateLimitService, NewAccountUsageService, NewAccountTestService, diff --git a/frontend/src/components/common/GroupSelector.vue b/frontend/src/components/common/GroupSelector.vue index b6d88ddd..1db827e6 100644 --- a/frontend/src/components/common/GroupSelector.vue +++ b/frontend/src/components/common/GroupSelector.vue @@ -62,6 +62,10 @@ const filteredGroups = computed(() => { if (!props.platform) { return props.groups } + // antigravity 账户可选择 anthropic 和 gemini 平台的分组 + if (props.platform === 'antigravity') { + return props.groups.filter((g) => g.platform === 'anthropic' || g.platform === 'gemini') + } return props.groups.filter((g) => g.platform === props.platform) })