feat: 完善 Antigravity 多平台网关支持,修复 Gemini handler 分流逻辑
This commit is contained in:
@@ -122,8 +122,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
timingWheelService := service.ProvideTimingWheelService()
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
@@ -133,7 +135,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -21,27 +21,30 @@ import (
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
geminiCompatService *service.GeminiMessagesCompatService
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
gatewayService *service.GatewayService
|
||||
geminiCompatService *service.GeminiMessagesCompatService
|
||||
antigravityGatewayService *service.AntigravityGatewayService
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
func NewGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
geminiCompatService *service.GeminiMessagesCompatService,
|
||||
antigravityGatewayService *service.AntigravityGatewayService,
|
||||
userService *service.UserService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
gatewayService: gatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,8 +166,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
@@ -240,8 +248,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
@@ -32,6 +32,13 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型列表
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@@ -69,6 +76,13 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
|
||||
if hasAntigravity {
|
||||
// antigravity 账户使用静态模型信息
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@@ -182,8 +196,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
|
||||
143
backend/internal/handler/gemini_v1beta_handler_test.go
Normal file
143
backend/internal/handler/gemini_v1beta_handler_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeminiV1BetaHandler_PlatformRoutingInvariant 文档化并验证 Handler 层的平台路由逻辑不变量
|
||||
// 该测试确保 gemini 和 antigravity 平台的路由逻辑符合预期
|
||||
func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台使用ForwardNative",
|
||||
platform: service.PlatformGemini,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
description: "Gemini OAuth 账户直接调用 Google API",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台使用ForwardGemini",
|
||||
platform: service.PlatformAntigravity,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go)
|
||||
var routedService string
|
||||
if tt.platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, routedService,
|
||||
"平台 %s 应该路由到 %s: %s",
|
||||
tt.platform, tt.expectedService, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_ListModelsAntigravityFallback 验证 ListModels 的 antigravity 降级逻辑
|
||||
// 当没有 gemini 账户但有 antigravity 账户时,应返回静态模型列表
|
||||
func TestGeminiV1BetaHandler_ListModelsAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态列表",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_fallback",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaListModels 的逻辑 (lines 33-44 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_fallback"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeminiV1BetaHandler_GetModelAntigravityFallback 验证 GetModel 的 antigravity 降级逻辑
|
||||
func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasGeminiAccount bool
|
||||
hasAntigravity bool
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "有Gemini账户-调用ForwardAIStudioGET",
|
||||
hasGeminiAccount: true,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "forward_to_upstream",
|
||||
},
|
||||
{
|
||||
name: "无Gemini有Antigravity-返回静态模型信息",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: true,
|
||||
expectedBehavior: "static_model_info",
|
||||
},
|
||||
{
|
||||
name: "无任何账户-返回503",
|
||||
hasGeminiAccount: false,
|
||||
hasAntigravity: false,
|
||||
expectedBehavior: "service_unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 GeminiV1BetaGetModel 的逻辑 (lines 77-87 in gemini_v1beta_handler.go)
|
||||
var behavior string
|
||||
|
||||
if tt.hasGeminiAccount {
|
||||
behavior = "forward_to_upstream"
|
||||
} else if tt.hasAntigravity {
|
||||
behavior = "static_model_info"
|
||||
} else {
|
||||
behavior = "service_unavailable"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedBehavior, behavior)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -337,6 +337,56 @@ func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform IN ?", platforms).
|
||||
Where("status = ? AND schedulable = ?", service.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.platform IN ?", platforms).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
|
||||
250
backend/internal/repository/gateway_routing_integration_test.go
Normal file
250
backend/internal/repository/gateway_routing_integration_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GatewayRoutingSuite 测试网关路由相关的数据库查询
|
||||
// 验证账户选择和分流逻辑在真实数据库环境下的行为
|
||||
type GatewayRoutingSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
accountRepo *accountRepository
|
||||
}
|
||||
|
||||
func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.accountRepo = NewAccountRepository(s.db).(*accountRepository)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayRoutingSuite))
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
|
||||
// 创建各平台账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "gemini-oauth",
|
||||
Platform: service.PlatformGemini,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 1,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "antigravity-oauth",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 2,
|
||||
Credentials: datatypes.JSONMap{
|
||||
"access_token": "test-token",
|
||||
"refresh_token": "test-refresh",
|
||||
"project_id": "test-project",
|
||||
},
|
||||
})
|
||||
|
||||
// 创建不应被选中的 anthropic 账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "anthropic-oauth",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 0,
|
||||
})
|
||||
|
||||
// 查询 gemini + antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
|
||||
|
||||
// 验证返回的账户平台
|
||||
platforms := make(map[string]bool)
|
||||
for _, acc := range accounts {
|
||||
platforms[acc.Platform] = true
|
||||
}
|
||||
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
|
||||
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
|
||||
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
|
||||
|
||||
// 验证账户 ID 匹配
|
||||
ids := make(map[int64]bool)
|
||||
for _, acc := range accounts {
|
||||
ids[acc.ID] = true
|
||||
}
|
||||
s.Require().True(ids[geminiAcc.ID])
|
||||
s.Require().True(ids[antigravityAcc.ID])
|
||||
}
|
||||
|
||||
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
|
||||
// 创建 gemini 分组
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
Name: "gemini-group",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// 创建账户
|
||||
boundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "bound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
unboundAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "unbound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只绑定一个账户到分组
|
||||
mustBindAccountToGroup(s.T(), s.db, boundAcc.ID, group.ID, 1)
|
||||
|
||||
// 查询分组内的账户
|
||||
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
|
||||
s.Require().Equal(boundAcc.ID, accounts[0].ID)
|
||||
|
||||
// 确认未绑定的账户不在结果中
|
||||
for _, acc := range accounts {
|
||||
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
|
||||
}
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
|
||||
// 创建多种平台账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "gemini-1",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravity := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "antigravity-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只查询 antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(antigravity.ID, accounts[0].ID)
|
||||
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
|
||||
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
// 创建可调度账户
|
||||
activeAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "active-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
|
||||
inactiveAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "inactive-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.Require().NoError(s.db.Model(&accountModel{}).Where("id = ?", inactiveAcc.ID).Update("schedulable", false).Error)
|
||||
|
||||
// 创建错误状态账户
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "error-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusError,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
|
||||
s.Require().Equal(activeAcc.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
// TestPlatformRoutingDecision 验证平台路由决策
|
||||
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
|
||||
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
|
||||
// 创建两种平台的账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "gemini-route-test",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "antigravity-route-test",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expectedService string
|
||||
}{
|
||||
{
|
||||
name: "Gemini账户路由到ForwardNative",
|
||||
accountID: geminiAcc.ID,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
},
|
||||
{
|
||||
name: "Antigravity账户路由到ForwardGemini",
|
||||
accountID: antigravityAcc.ID,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 从数据库获取账户
|
||||
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// 模拟 Handler 层的路由决策
|
||||
var routedService string
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
s.Require().Equal(tt.expectedService, routedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,8 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
|
||||
845
backend/internal/service/antigravity_gateway_service.go
Normal file
845
backend/internal/service/antigravity_gateway_service.go
Normal file
@@ -0,0 +1,845 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityMaxRetries = 5
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// Antigravity 直接支持的模型
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
"claude-sonnet-4-5": true,
|
||||
"claude-sonnet-4-5-thinking": true,
|
||||
"gemini-2.5-flash": true,
|
||||
"gemini-2.5-flash-lite": true,
|
||||
"gemini-2.5-flash-thinking": true,
|
||||
"gemini-3-flash": true,
|
||||
"gemini-3-pro-low": true,
|
||||
"gemini-3-pro-high": true,
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image": true,
|
||||
}
|
||||
|
||||
// Antigravity 系统默认模型映射表(不支持 → 支持)
|
||||
var antigravityModelMapping = map[string]string{
|
||||
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
|
||||
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
|
||||
"claude-opus-4": "claude-opus-4-5-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
|
||||
"claude-haiku-4": "claude-sonnet-4-5",
|
||||
"claude-3-haiku-20240307": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
}
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
type AntigravityGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
cache GatewayCache
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
cache GatewayCache,
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
cache: cache,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTokenProvider 返回 token provider
|
||||
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
|
||||
return s.tokenProvider
|
||||
}
|
||||
|
||||
// getMappedModel 获取映射后的模型名
|
||||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||||
// 1. 优先使用账户级映射(复用现有方法)
|
||||
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 2. 系统默认映射
|
||||
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 3. Gemini 模型透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 4. Claude 前缀透传直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 5. 默认值
|
||||
return "claude-sonnet-4-5"
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否被支持
|
||||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||||
var request any
|
||||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||||
}
|
||||
|
||||
wrapped := map[string]any{
|
||||
"project": projectID,
|
||||
"requestId": "agent-" + uuid.New().String(),
|
||||
"userAgent": "sub2api",
|
||||
"requestType": "agent",
|
||||
"model": model,
|
||||
"request": request,
|
||||
}
|
||||
|
||||
return json.Marshal(wrapped)
|
||||
}
|
||||
|
||||
// unwrapV1InternalResponse 解包 v1internal 响应
|
||||
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
||||
var outer map[string]any
|
||||
if err := json.Unmarshal(body, &outer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp, ok := outer["response"]; ok {
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// unwrapSSELine 解包 SSE 行中的 v1internal 响应
|
||||
func (s *AntigravityGatewayService) unwrapSSELine(line string) string {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return line
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var outer map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &outer); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
if resp, ok := outer["response"]; ok {
|
||||
unwrapped, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(unwrapped)
|
||||
}
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
// Forward 转发 Claude 协议请求
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取 model 和 stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := s.getMappedModel(account, req.Model)
|
||||
if mappedModel != req.Model {
|
||||
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
|
||||
}
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 project_id
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
return nil, errors.New("project_id not found in credentials")
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建上游 URL
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
|
||||
if req.Stream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
if req.Stream {
|
||||
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: req.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if strings.TrimSpace(originalModel) == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
}
|
||||
if strings.TrimSpace(action) == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent", "countTokens":
|
||||
// ok
|
||||
default:
|
||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||
}
|
||||
|
||||
mappedModel := s.getMappedModel(account, originalModel)
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 project_id
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
return nil, errors.New("project_id not found in credentials")
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建上游 URL
|
||||
upstreamAction := action
|
||||
if action == "generateContent" && stream {
|
||||
upstreamAction = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
|
||||
if stream || upstreamAction == "streamGenerateContent" {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 解包并返回错误
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, unwrapped)
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
if stream || upstreamAction == "streamGenerateContent" {
|
||||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
} else {
|
||||
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = usageResp
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
if s.rateLimitService == nil {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
}
|
||||
|
||||
type antigravityStreamResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
if len(line) > 0 {
|
||||
// 解包 v1internal 响应
|
||||
unwrapped := s.unwrapSSELine(strings.TrimRight(line, "\r\n"))
|
||||
|
||||
// 解析 usage
|
||||
if strings.HasPrefix(unwrapped, "data: ") {
|
||||
data := strings.TrimPrefix(unwrapped, "data: ")
|
||||
if data != "" && data != "[DONE]" {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseClaudeSSEUsage(data, usage)
|
||||
}
|
||||
}
|
||||
|
||||
// 写入响应
|
||||
if _, writeErr := fmt.Fprintf(c.Writer, "%s\n", unwrapped); writeErr != nil {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, writeErr
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
unwrapped, err := s.unwrapV1InternalResponse(body)
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var respObj struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
_ = json.Unmarshal(unwrapped, &respObj)
|
||||
|
||||
c.Data(http.StatusOK, "application/json", unwrapped)
|
||||
return &respObj.Usage, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "text/event-stream; charset=utf-8"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(unwrapped, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
c.Data(resp.StatusCode, "application/json", unwrapped)
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", unwrapped)
|
||||
return &ClaudeUsage{}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) parseClaudeSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析 message_start 获取 input tokens
|
||||
var msgStart struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
} `json:"message"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
|
||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
// 解析 message_delta 获取 output tokens
|
||||
var msgDelta struct {
|
||||
Type string `json:"type"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||
if usage.InputTokens == 0 {
|
||||
usage.InputTokens = msgDelta.Usage.InputTokens
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 {
|
||||
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": message},
|
||||
})
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
|
||||
switch upstreamStatus {
|
||||
case 400:
|
||||
statusCode = http.StatusBadRequest
|
||||
errType = "invalid_request_error"
|
||||
errMsg = "Invalid request"
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "authentication_error"
|
||||
errMsg = "Upstream authentication failed"
|
||||
case 403:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "permission_error"
|
||||
errMsg = "Upstream access forbidden"
|
||||
case 429:
|
||||
statusCode = http.StatusTooManyRequests
|
||||
errType = "rate_limit_error"
|
||||
errMsg = "Upstream rate limit exceeded"
|
||||
case 529:
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
errType = "overloaded_error"
|
||||
errMsg = "Upstream service overloaded"
|
||||
default:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream request failed"
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": errMsg},
|
||||
})
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
statusStr := "UNKNOWN"
|
||||
switch status {
|
||||
case 400:
|
||||
statusStr = "INVALID_ARGUMENT"
|
||||
case 404:
|
||||
statusStr = "NOT_FOUND"
|
||||
case 429:
|
||||
statusStr = "RESOURCE_EXHAUSTED"
|
||||
case 500:
|
||||
statusStr = "INTERNAL"
|
||||
case 502, 503:
|
||||
statusStr = "UNAVAILABLE"
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": statusStr,
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
257
backend/internal/service/antigravity_model_mapping_test.go
Normal file
257
backend/internal/service/antigravity_model_mapping_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAntigravityModelSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持的模型
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
|
||||
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
|
||||
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
|
||||
|
||||
// 可映射的模型
|
||||
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
|
||||
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
|
||||
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
|
||||
|
||||
// Gemini 前缀透传
|
||||
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
|
||||
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
|
||||
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
|
||||
|
||||
// Claude 前缀兜底
|
||||
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
|
||||
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
|
||||
{"Claude前缀 - claude-future-version", "claude-future-version", true},
|
||||
|
||||
// 不支持的模型
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - gpt-4o", "gpt-4o", false},
|
||||
{"不支持 - llama-3", "llama-3", false},
|
||||
{"不支持 - mistral-7b", "mistral-7b", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsAntigravityModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
accountMapping map[string]string
|
||||
expected string
|
||||
}{
|
||||
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
|
||||
{
|
||||
name: "账户映射优先",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
|
||||
expected: "custom-model",
|
||||
},
|
||||
{
|
||||
name: "账户映射覆盖系统映射",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
|
||||
expected: "my-opus",
|
||||
},
|
||||
|
||||
// 2. 系统默认映射
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20241022",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20240620",
|
||||
requestedModel: "claude-3-5-sonnet-20240620",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4-5-20251101",
|
||||
requestedModel: "claude-opus-4-5-20251101",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4",
|
||||
requestedModel: "claude-haiku-4",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-haiku-20240307",
|
||||
requestedModel: "claude-3-haiku-20240307",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||
requestedModel: "claude-sonnet-4-5-20250929",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-1.5-pro",
|
||||
requestedModel: "gemini-1.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-1.5-pro",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
requestedModel: "gemini-future-model",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-future-model",
|
||||
},
|
||||
|
||||
// 4. 直接支持的模型
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-opus-4-5-thinking",
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5-thinking",
|
||||
requestedModel: "claude-sonnet-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// 5. 默认值 fallback(未知 claude 模型)
|
||||
{
|
||||
name: "默认值 - claude-unknown",
|
||||
requestedModel: "claude-unknown",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认值 - claude-3-opus-20240229",
|
||||
requestedModel: "claude-3-opus-20240229",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
if tt.accountMapping != nil {
|
||||
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
|
||||
mappingAny := make(map[string]any)
|
||||
for k, v := range tt.accountMapping {
|
||||
mappingAny[k] = v
|
||||
}
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": mappingAny,
|
||||
}
|
||||
}
|
||||
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
// 空字符串回退到默认值
|
||||
{"空字符串", "", "claude-sonnet-4-5"},
|
||||
|
||||
// 非 claude/gemini 前缀回退到默认值
|
||||
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
|
||||
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
||||
|
||||
// 可映射
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
|
||||
// 前缀透传
|
||||
{"Gemini前缀", "gemini-unknown", true},
|
||||
{"Claude前缀", "claude-unknown", true},
|
||||
|
||||
// 不支持
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.IsModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
145
backend/internal/service/antigravity_token_provider.go
Normal file
145
backend/internal/service/antigravity_token_provider.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type AntigravityTokenCache = GeminiTokenCache
|
||||
|
||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
||||
type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache AntigravityTokenCache,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
) *AntigravityTokenProvider {
|
||||
return &AntigravityTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := antigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := parseAntigravityExpiresAt(account)
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
||||
if p.antigravityOAuthService == nil {
|
||||
return "", errors.New("antigravity oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func antigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
}
|
||||
return "ag:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func parseAntigravityExpiresAt(account *Account) *time.Time {
|
||||
raw := strings.TrimSpace(account.GetCredential("expires_at"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
|
||||
t := time.Unix(unixSec, 0)
|
||||
return &t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
backend/internal/service/antigravity_token_refresher.go
Normal file
57
backend/internal/service/antigravity_token_refresher.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AntigravityTokenRefresher 实现 TokenRefresher 接口
|
||||
type AntigravityTokenRefresher struct {
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
|
||||
return &AntigravityTokenRefresher{
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否可以刷新此账户
|
||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
}
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
565
backend/internal/service/gateway_multiplatform_test.go
Normal file
565
backend/internal/service/gateway_multiplatform_test.go
Normal file
@@ -0,0 +1,565 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockAccountRepoForMultiplatform 多平台测试用的 mock
|
||||
type mockAccountRepoForMultiplatform struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listPlatformsFunc func(ctx context.Context, platforms []string) ([]Account, error)
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForMultiplatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
if m.listPlatformsFunc != nil {
|
||||
return m.listPlatformsFunc(ctx, platforms)
|
||||
}
|
||||
// 过滤符合平台的账户
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
for _, acc := range m.accounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
|
||||
// Stub methods to implement AccountRepository interface
|
||||
func (m *mockAccountRepoForMultiplatform) Create(ctx context.Context, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) Update(ctx context.Context, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForMultiplatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListActive(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForMultiplatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForMultiplatform)(nil)
|
||||
|
||||
// mockGatewayCacheForMultiplatform 多平台测试用的 cache mock
|
||||
type mockGatewayCacheForMultiplatform struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForMultiplatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForMultiplatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
m.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForMultiplatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAnthropic(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform)
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_MixedPlatforms_DiffPriority(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选择优先级更高的账户(Antigravity, priority=1)")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_ModelNotSupported(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
// Anthropic 账户配置了模型映射,只支持 other-model
|
||||
// 注意:model_mapping 需要是 map[string]any 格式
|
||||
{
|
||||
ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"other-model": "x"}},
|
||||
},
|
||||
// Antigravity 账户支持所有 claude 模型
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "Anthropic 不支持该模型,应选择 Antigravity")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_AllExcluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_Schedulability(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accounts []Account
|
||||
expectedID int64
|
||||
}{
|
||||
{
|
||||
name: "过载账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "限流账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "非active账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "schedulable=false被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "过期的过载账户可调度",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: tt.accounts,
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, tt.expectedID, acc.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("粘性会话命中", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForMultiplatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForMultiplatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatforms(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, []string{PlatformAnthropic, PlatformAntigravity})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Antigravity平台-支持claude模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-支持gemini模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-不支持gpt模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformAnthropic},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
|
||||
},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-有映射配置-支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
|
||||
},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isModelSupportedByAccount(tt.account, tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -291,6 +291,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 使用多平台账户选择,包含 anthropic 和 antigravity 平台
|
||||
platforms := []string{PlatformAnthropic, PlatformAntigravity}
|
||||
return s.selectAccountForModelWithPlatforms(ctx, groupID, sessionHash, requestedModel, excludedIDs, platforms)
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatforms 选择多平台账户
|
||||
func (s *GatewayService) selectAccountForModelWithPlatforms(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platforms []string) (*Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
@@ -298,8 +305,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 同时检查模型支持(根据平台类型分别处理)
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
// 续期粘性会话
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
@@ -310,13 +317,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,支持多平台)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -329,8 +336,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 检查模型支持
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
// 检查模型支持(根据平台类型分别处理)
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
@@ -374,6 +381,37 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
// Antigravity 平台使用专门的模型支持检查
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
||||
func IsAntigravityModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
|
||||
@@ -33,11 +33,12 @@ const (
|
||||
)
|
||||
|
||||
type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
cache GatewayCache
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
accountRepo AccountRepository
|
||||
cache GatewayCache
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -46,13 +47,15 @@ func NewGeminiMessagesCompatService(
|
||||
tokenProvider *GeminiTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
) *GeminiMessagesCompatService {
|
||||
return &GeminiMessagesCompatService{
|
||||
accountRepo: accountRepo,
|
||||
cache: cache,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
accountRepo: accountRepo,
|
||||
cache: cache,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,12 +70,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
platforms := []string{PlatformGemini, PlatformAntigravity}
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 支持 gemini 和 antigravity 平台的粘性会话
|
||||
if err == nil && account.IsSchedulable() && (account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
@@ -80,12 +86,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
|
||||
// 同时查询 gemini 和 antigravity 平台的可调度账户
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -97,7 +104,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
// 根据平台类型分别检查模型支持
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
@@ -127,9 +135,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
|
||||
return nil, fmt.Errorf("no available Gemini/Antigravity accounts supporting model: %s", requestedModel)
|
||||
}
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
return nil, errors.New("no available Gemini/Antigravity accounts")
|
||||
}
|
||||
|
||||
if sessionHash != "" {
|
||||
@@ -139,6 +147,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
// GetAntigravityGatewayService 返回 AntigravityGatewayService
|
||||
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(accounts) > 0, nil
|
||||
}
|
||||
|
||||
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
|
||||
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
|
||||
//
|
||||
|
||||
568
backend/internal/service/gemini_multiplatform_test.go
Normal file
568
backend/internal/service/gemini_multiplatform_test.go
Normal file
@@ -0,0 +1,568 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
|
||||
// Stub methods to implement AccountRepository interface
|
||||
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) { return nil, nil }
|
||||
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
// 测试时不区分 groupID,直接按 platform 过滤
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
m.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyGemini(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OnlyAntigravity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludesAnthropic(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 3, Platform: PlatformAntigravity, Priority: 3, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
// Anthropic 不在 [gemini, antigravity] 平台列表中,应被过滤
|
||||
require.Equal(t, int64(2), acc.ID, "Anthropic 平台应被排除,选择 Gemini")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_MixedPlatforms_SamePriority(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选择最久未用的账户(Antigravity)")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
|
||||
require.Equal(t, AccountTypeOAuth, acc.Type)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred_MixedPlatforms(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "跨平台时,同样优先选择 OAuth 账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available Gemini/Antigravity accounts")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("粘性会话命中-使用gemini前缀缓存键", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 注意:缓存键使用 "gemini:" 前缀
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 缓存键没有 "gemini:" 前缀,不应命中
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择 Antigravity")
|
||||
})
|
||||
|
||||
t.Run("粘性会话Anthropic账户-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
// 粘性会话绑定的是 Anthropic 账户,不在 Gemini 平台列表中,应降级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户是 Anthropic,应降级选择 Gemini 平台账户")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_HasAntigravityAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("有antigravity账户时返回true", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{accountRepo: repo}
|
||||
|
||||
has, err := svc.HasAntigravityAccounts(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, has)
|
||||
})
|
||||
|
||||
t.Run("无antigravity账户时返回false", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{accountRepo: repo}
|
||||
|
||||
has, err := svc.HasAntigravityAccounts(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, has)
|
||||
})
|
||||
|
||||
t.Run("antigravity账户不可调度时返回false", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: false},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{accountRepo: repo}
|
||||
|
||||
has, err := svc.HasAntigravityAccounts(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, has)
|
||||
})
|
||||
|
||||
t.Run("带groupID查询", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{accountRepo: repo}
|
||||
|
||||
groupID := int64(1)
|
||||
has, err := svc.HasAntigravityAccounts(ctx, &groupID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, has)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
// 该测试文档化了 Handler 层应该如何根据 account.Platform 进行分流
|
||||
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台走ForwardNative",
|
||||
platform: PlatformGemini,
|
||||
expectedService: "gemini",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台走ForwardGemini",
|
||||
platform: PlatformAntigravity,
|
||||
expectedService: "antigravity",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: tt.platform}
|
||||
|
||||
// 模拟 Handler 层的路由逻辑
|
||||
var serviceName string
|
||||
if account.Platform == PlatformAntigravity {
|
||||
serviceName = "antigravity"
|
||||
} else {
|
||||
serviceName = "gemini"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, serviceName,
|
||||
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Antigravity平台-支持gemini模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-支持claude模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-不支持gpt模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isModelSupportedByAccount(tt.account, tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ func NewTokenRefreshService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
NewOpenAITokenRefresher(openaiOAuthService),
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -84,6 +85,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
NewAntigravityTokenProvider,
|
||||
NewAntigravityGatewayService,
|
||||
NewRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
|
||||
@@ -62,6 +62,10 @@ const filteredGroups = computed(() => {
|
||||
if (!props.platform) {
|
||||
return props.groups
|
||||
}
|
||||
// antigravity 账户可选择 anthropic 和 gemini 平台的分组
|
||||
if (props.platform === 'antigravity') {
|
||||
return props.groups.filter((g) => g.platform === 'anthropic' || g.platform === 'gemini')
|
||||
}
|
||||
return props.groups.filter((g) => g.platform === props.platform)
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user