feat: 完善 Antigravity 多平台网关支持,修复 Gemini handler 分流逻辑

This commit is contained in:
song
2025-12-28 17:48:52 +08:00
parent 6648e6506c
commit 1d085d982b
18 changed files with 3042 additions and 43 deletions

View File

@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,

View File

@@ -21,27 +21,30 @@ import (
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(
gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
}
}
@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 转发请求
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}

View File

@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型列表
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型信息
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
// 5) forward (根据平台分流)
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
} else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}

View File

@@ -0,0 +1,143 @@
//go:build unit
package handler
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string
description string
}{
{
name: "Gemini平台使用ForwardNative",
platform: service.PlatformGemini,
expectedService: "GeminiMessagesCompatService.ForwardNative",
description: "Gemini OAuth 账户直接调用 Google API",
},
{
name: "Antigravity平台使用ForwardGemini",
platform: service.PlatformAntigravity,
expectedService: "AntigravityGatewayService.ForwardGemini",
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
var routedService string
if tt.platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
require.Equal(t, tt.expectedService, routedService,
"平台 %s 应该路由到 %s: %s",
tt.platform, tt.expectedService, tt.description)
})
}
}
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态列表",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_fallback",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_fallback"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
tests := []struct {
name string
hasGeminiAccount bool
hasAntigravity bool
expectedBehavior string
}{
{
name: "有Gemini账户-调用ForwardAIStudioGET",
hasGeminiAccount: true,
hasAntigravity: false,
expectedBehavior: "forward_to_upstream",
},
{
name: "无Gemini有Antigravity-返回静态模型信息",
hasGeminiAccount: false,
hasAntigravity: true,
expectedBehavior: "static_model_info",
},
{
name: "无任何账户-返回503",
hasGeminiAccount: false,
hasAntigravity: false,
expectedBehavior: "service_unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
var behavior string
if tt.hasGeminiAccount {
behavior = "forward_to_upstream"
} else if tt.hasAntigravity {
behavior = "static_model_info"
} else {
behavior = "service_unavailable"
}
require.Equal(t, tt.expectedBehavior, behavior)
})
}
}

View File

@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
return outAccounts, nil
}
func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform IN ?", platforms).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform IN ?", platforms).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).

View File

@@ -0,0 +1,250 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type GatewayRoutingSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
accountRepo *accountRepository
}
func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.accountRepo = NewAccountRepository(s.db).(*accountRepository)
}
func TestGatewayRoutingSuite(t *testing.T) {
suite.Run(t, new(GatewayRoutingSuite))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
// 创建各平台账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-oauth",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 1,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-oauth",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 2,
Credentials: datatypes.JSONMap{
"access_token": "test-token",
"refresh_token": "test-refresh",
"project_id": "test-project",
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "anthropic-oauth",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 0,
})
// 查询 gemini + antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
// 验证返回的账户平台
platforms := make(map[string]bool)
for _, acc := range accounts {
platforms[acc.Platform] = true
}
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
// 验证账户 ID 匹配
ids := make(map[int64]bool)
for _, acc := range accounts {
ids[acc.ID] = true
}
s.Require().True(ids[geminiAcc.ID])
s.Require().True(ids[antigravityAcc.ID])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
// 创建 gemini 分组
group := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "gemini-group",
Platform: service.PlatformGemini,
Status: service.StatusActive,
})
// 创建账户
boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "unbound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只绑定一个账户到分组
mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1)
// 查询分组内的账户
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
s.Require().Equal(boundAcc.ID, accounts[0].ID)
// 确认未绑定的账户不在结果中
for _, acc := range accounts {
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// 创建多种平台账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-1",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravity := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-1",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只查询 antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(antigravity.ID, accounts[0].ID)
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 创建可调度账户
activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "active-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true
inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "inactive-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
})
s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error)
// 创建错误状态账户
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "error-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusError,
Schedulable: true,
})
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
s.Require().Equal(activeAcc.ID, accounts[0].ID)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
// 创建两种平台的账户
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "gemini-route-test",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "antigravity-route-test",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
tests := []struct {
name string
accountID int64
expectedService string
}{
{
name: "Gemini账户路由到ForwardNative",
accountID: geminiAcc.ID,
expectedService: "GeminiMessagesCompatService.ForwardNative",
},
{
name: "Antigravity账户路由到ForwardGemini",
accountID: antigravityAcc.ID,
expectedService: "AntigravityGatewayService.ForwardGemini",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 从数据库获取账户
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
s.Require().NoError(err)
// 模拟 Handler 层的路由决策
var routedService string
if account.Platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
s.Require().Equal(tt.expectedService, routedService)
})
}
}

View File

@@ -38,6 +38,8 @@ type AccountRepository interface {
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error

View File

@@ -0,0 +1,845 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 5
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "claude-sonnet-4-5",
"claude-3-haiku-20240307": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
type AntigravityGatewayService struct {
accountRepo AccountRepository
cache GatewayCache
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
}
func NewAntigravityGatewayService(
accountRepo AccountRepository,
cache GatewayCache,
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
}
}
// GetTokenProvider 返回 token provider
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
return s.tokenProvider
}
// getMappedModel 获取映射后的模型名
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法)
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 5. 默认值
return "claude-sonnet-4-5"
}
// IsModelSupported 检查模型是否被支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
if err := json.Unmarshal(originalBody, &request); err != nil {
return nil, fmt.Errorf("解析请求体失败: %w", err)
}
wrapped := map[string]any{
"project": projectID,
"requestId": "agent-" + uuid.New().String(),
"userAgent": "sub2api",
"requestType": "agent",
"model": model,
"request": request,
}
return json.Marshal(wrapped)
}
// unwrapV1InternalResponse 解包 v1internal 响应
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
// unwrapSSELine 解包 SSE 行中的 v1internal 响应
func (s *AntigravityGatewayService) unwrapSSELine(line string) string {
if !strings.HasPrefix(line, "data: ") {
return line
}
data := strings.TrimPrefix(line, "data: ")
if data == "" || data == "[DONE]" {
return line
}
var outer map[string]any
if err := json.Unmarshal([]byte(data), &outer); err != nil {
return line
}
if resp, ok := outer["response"]; ok {
unwrapped, err := json.Marshal(resp)
if err != nil {
return line
}
return "data: " + string(unwrapped)
}
return line
}
// Forward 转发 Claude 协议请求
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
// 解析请求获取 model 和 stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
if strings.TrimSpace(req.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := req.Model
mappedModel := s.getMappedModel(account, req.Model)
if mappedModel != req.Model {
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
}
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
if err != nil {
return nil, err
}
// 构建上游 URL
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
if req.Stream {
fullURL += "?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
continue
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
continue
}
// 最后一次尝试也失败
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
}
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
if req.Stream {
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := s.getMappedModel(account, originalModel)
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
if err != nil {
return nil, err
}
// 构建上游 URL
upstreamAction := action
if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
if stream || upstreamAction == "streamGenerateContent" {
fullURL += "?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 解包并返回错误
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, unwrapped)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream || upstreamAction == "streamGenerateContent" {
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
if err != nil {
return nil, err
}
usage = usageResp
}
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
default:
return false
}
}
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func sleepAntigravityBackoff(attempt int) {
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService == nil {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
}
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func (s *AntigravityGatewayService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &ClaudeUsage{}
var firstTokenMs *int
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
if len(line) > 0 {
// 解包 v1internal 响应
unwrapped := s.unwrapSSELine(strings.TrimRight(line, "\r\n"))
// 解析 usage
if strings.HasPrefix(unwrapped, "data: ") {
data := strings.TrimPrefix(unwrapped, "data: ")
if data != "" && data != "[DONE]" {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseClaudeSSEUsage(data, usage)
}
}
// 写入响应
if _, writeErr := fmt.Fprintf(c.Writer, "%s\n", unwrapped); writeErr != nil {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
if errors.Is(err, io.EOF) {
break
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
// 解包 v1internal 响应
unwrapped, err := s.unwrapV1InternalResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
// 解析 usage
var respObj struct {
Usage ClaudeUsage `json:"usage"`
}
_ = json.Unmarshal(unwrapped, &respObj)
c.Data(http.StatusOK, "application/json", unwrapped)
return &respObj.Usage, nil
}
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
flusher.Flush()
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解包 v1internal 响应
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
var parsed map[string]any
if json.Unmarshal(unwrapped, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
c.Data(resp.StatusCode, "application/json", unwrapped)
return u, nil
}
}
c.Data(resp.StatusCode, "application/json", unwrapped)
return &ClaudeUsage{}, nil
}
func (s *AntigravityGatewayService) parseClaudeSSEUsage(data string, usage *ClaudeUsage) {
// 解析 message_start 获取 input tokens
var msgStart struct {
Type string `json:"type"`
Message struct {
Usage ClaudeUsage `json:"usage"`
} `json:"message"`
}
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
}
// 解析 message_delta 获取 output tokens
var msgDelta struct {
Type string `json:"type"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
usage.OutputTokens = msgDelta.Usage.OutputTokens
if usage.InputTokens == 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": message},
})
return fmt.Errorf("%s", message)
}
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
var statusCode int
var errType, errMsg string
switch upstreamStatus {
case 400:
statusCode = http.StatusBadRequest
errType = "invalid_request_error"
errMsg = "Invalid request"
case 401:
statusCode = http.StatusBadGateway
errType = "authentication_error"
errMsg = "Upstream authentication failed"
case 403:
statusCode = http.StatusBadGateway
errType = "permission_error"
errMsg = "Upstream access forbidden"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN"
switch status {
case 400:
statusStr = "INVALID_ARGUMENT"
case 404:
statusStr = "NOT_FOUND"
case 429:
statusStr = "RESOURCE_EXHAUSTED"
case 500:
statusStr = "INTERNAL"
case 502, 503:
statusStr = "UNAVAILABLE"
}
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": statusStr,
},
})
return fmt.Errorf("%s", message)
}

View File

@@ -0,0 +1,257 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAntigravityModelSupported(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// 直接支持的模型
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
// 可映射的模型
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// Claude 前缀兜底
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
{"Claude前缀 - claude-future-version", "claude-future-version", true},
// 不支持的模型
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false},
{"不支持 - mistral-7b", "mistral-7b", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsAntigravityModelSupported(tt.model)
require.Equal(t, tt.expected, got, "model: %s", tt.model)
})
}
}
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先注意model_mapping 在 credentials 中存储为 map[string]any
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-haiku-4",
requestedModel: "claude-haiku-4",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-haiku-20240307",
requestedModel: "claude-3-haiku-20240307",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 透传
{
name: "Gemini透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
},
{
name: "Gemini透传 - gemini-1.5-pro",
requestedModel: "gemini-1.5-pro",
accountMapping: nil,
expected: "gemini-1.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "gemini-future-model",
},
// 4. 直接支持的模型
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 5. 默认值 fallback未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
}
if tt.accountMapping != nil {
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny := make(map[string]any)
for k, v := range tt.accountMapping {
mappingAny[k] = v
}
account.Credentials = map[string]any{
"model_mapping": mappingAny,
}
}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
})
}
}
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: PlatformAntigravity}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got)
})
}
}
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
model string
expected bool
}{
// 直接支持
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 前缀透传
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
// 不支持
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.IsModelSupported(tt.model)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -0,0 +1,145 @@
package service
import (
"context"
"errors"
"log"
"strconv"
"strings"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenProvider(
accountRepo AccountRepository,
tokenCache AntigravityTokenCache,
antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
return &AntigravityTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account)
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseAntigravityExpiresAt(account)
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = parseAntigravityExpiresAt(account)
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}

View File

@@ -0,0 +1,57 @@
package service
import (
"context"
"strconv"
"time"
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type AntigravityTokenRefresher struct {
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
return &AntigravityTokenRefresher{
antigravityOAuthService: antigravityOAuthService,
}
}
// CanRefresh 检查是否可以刷新此账户
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
}
// Refresh 执行 token 刷新
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
return newCredentials, nil
}

View File

@@ -0,0 +1,565 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForMultiplatform 多平台测试用的 mock
type mockAccountRepoForMultiplatform struct {
accounts []Account
accountsByID map[int64]*Account
listPlatformsFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForMultiplatform) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listPlatformsFunc != nil {
return m.listPlatformsFunc(ctx, platforms)
}
// 过滤符合平台的账户
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForMultiplatform) Create(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) Update(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForMultiplatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListActive(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) UpdateLastUsed(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForMultiplatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForMultiplatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForMultiplatform)(nil)
// mockGatewayCacheForMultiplatform 多平台测试用的 cache mock
type mockGatewayCacheForMultiplatform struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForMultiplatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForMultiplatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForMultiplatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
func ptr[T any](v T) *T {
return &v
}
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAnthropic(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
}
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
}
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
ctx := context.Background()
now := time.Now()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户Antigravity")
}
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_DiffPriority(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级更高的账户Antigravity, priority=1")
}
func TestGatewayService_SelectAccountForModelWithExclusions_ModelNotSupported(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
// Anthropic 账户配置了模型映射,只支持 other-model
// 注意model_mapping 需要是 map[string]any 格式
{
ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"other-model": "x"}},
},
// Antigravity 账户支持所有 claude 模型
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "Anthropic 不支持该模型,应选择 Antigravity")
}
func TestGatewayService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
}
func TestGatewayService_SelectAccountForModelWithExclusions_AllExcluded(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
require.Error(t, err)
require.Nil(t, acc)
}
func TestGatewayService_SelectAccountForModelWithExclusions_Schedulability(t *testing.T) {
ctx := context.Background()
now := time.Now()
tests := []struct {
name string
accounts []Account
expectedID int64
}{
{
name: "过载账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "限流账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "非active账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "schedulable=false被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "过期的过载账户可调度",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: tt.accounts,
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, tt.expectedID, acc.ID)
})
}
}
func TestGatewayService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
excludedIDs := map[int64]struct{}{1: {}}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户")
})
t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForMultiplatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForMultiplatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户")
})
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Anthropic平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformAnthropic},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Anthropic平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Anthropic平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 使用多平台账户选择,包含 anthropic 和 antigravity 平台
platforms := []string{PlatformAnthropic, PlatformAntigravity}
return s.selectAccountForModelWithPlatforms(ctx, groupID, sessionHash, requestedModel, excludedIDs, platforms)
}
// selectAccountForModelWithPlatforms 选择多平台账户
func (s *GatewayService) selectAccountForModelWithPlatforms(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platforms []string) (*Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 同时检查模型支持(根据平台类型分别处理)
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
// 续期粘性会话
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
}
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
// 2. 获取可调度账号列表(排除限流和过载的账号,支持多平台)
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查模型支持
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
// 检查模型支持(根据平台类型分别处理)
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {

View File

@@ -33,11 +33,12 @@ const (
)
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
cache GatewayCache
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
accountRepo AccountRepository
cache GatewayCache
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
}
func NewGeminiMessagesCompatService(
@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService(
tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
accountRepo: accountRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
}
}
@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
cacheKey := "gemini:" + sessionHash
platforms := []string{PlatformGemini, PlatformAntigravity}
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 支持 gemini 和 antigravity 平台的粘性会话
if err == nil && account.IsSchedulable() && (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
// 同时查询 gemini 和 antigravity 平台的可调度账户
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
// 根据平台类型分别检查模型支持
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
return nil, fmt.Errorf("no available Gemini/Antigravity accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
return nil, errors.New("no available Gemini/Antigravity accounts")
}
if sessionHash != "" {
@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
}
return account.IsModelSupported(requestedModel)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
}
if err != nil {
return false, err
}
return len(accounts) > 0, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//

View File

@@ -0,0 +1,568 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
var result []Account
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) { return nil, nil }
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
// 测试时不区分 groupID直接按 platform 过滤
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyGemini(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludesAnthropic(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 3, Platform: PlatformAntigravity, Priority: 3, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// Anthropic 不在 [gemini, antigravity] 平台列表中,应被过滤
require.Equal(t, int64(2), acc.ID, "Anthropic 平台应被排除,选择 Gemini")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
ctx := context.Background()
now := time.Now()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户Antigravity")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
require.Equal(t, AccountTypeOAuth, acc.Type)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred_MixedPlatforms(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "跨平台时,同样优先选择 OAuth 账户")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForGemini{}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available Gemini/Antigravity accounts")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中-使用gemini前缀缓存键", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 注意:缓存键使用 "gemini:" 前缀
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 缓存键没有 "gemini:" 前缀,不应命中
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择 Antigravity")
})
t.Run("粘性会话Anthropic账户-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// 粘性会话绑定的是 Anthropic 账户,不在 Gemini 平台列表中,应降级选择
require.Equal(t, int64(2), acc.ID, "粘性会话账户是 Anthropic应降级选择 Gemini 平台账户")
})
}
func TestGeminiMessagesCompatService_HasAntigravityAccounts(t *testing.T) {
ctx := context.Background()
t.Run("有antigravity账户时返回true", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.True(t, has)
})
t.Run("无antigravity账户时返回false", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.False(t, has)
})
t.Run("antigravity账户不可调度时返回false", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: false},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
has, err := svc.HasAntigravityAccounts(ctx, nil)
require.NoError(t, err)
require.False(t, has)
})
t.Run("带groupID查询", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
},
}
svc := &GeminiMessagesCompatService{accountRepo: repo}
groupID := int64(1)
has, err := svc.HasAntigravityAccounts(ctx, &groupID)
require.NoError(t, err)
require.True(t, has)
})
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
// 该测试文档化了 Handler 层应该如何根据 account.Platform 进行分流
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
}{
{
name: "Gemini平台走ForwardNative",
platform: PlatformGemini,
expectedService: "gemini",
},
{
name: "Antigravity平台走ForwardGemini",
platform: PlatformAntigravity,
expectedService: "antigravity",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: tt.platform}
// 模拟 Handler 层的路由逻辑
var serviceName string
if account.Platform == PlatformAntigravity {
serviceName = "antigravity"
} else {
serviceName = "gemini"
}
require.Equal(t, tt.expectedService, serviceName,
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
})
}
}
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -27,6 +27,7 @@ func NewTokenRefreshService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService),
NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
}
return s

View File

@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
svc.Start()
return svc
}
@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet(
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
NewAntigravityGatewayService,
NewRateLimitService,
NewAccountUsageService,
NewAccountTestService,

View File

@@ -62,6 +62,10 @@ const filteredGroups = computed(() => {
if (!props.platform) {
return props.groups
}
// antigravity 账户可选择 anthropic 和 gemini 平台的分组
if (props.platform === 'antigravity') {
return props.groups.filter((g) => g.platform === 'anthropic' || g.platform === 'gemini')
}
return props.groups.filter((g) => g.platform === props.platform)
})