diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index a12d3790..7b22a31e 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
- accountRepository := repository.NewAccountRepository(client, db)
+ schedulerCache := repository.NewSchedulerCache(redisClient)
+ accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
@@ -129,7 +130,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
- schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 5a543d6c..0e3e0a2f 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -47,6 +47,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -90,6 +91,7 @@ type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -240,6 +242,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
+ PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@@ -314,6 +317,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+ PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index 882e4cf2..89f34aae 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -195,6 +195,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
+ // 检查优惠码功能是否启用
+ if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
+ response.Success(c, ValidatePromoCodeResponse{
+ Valid: false,
+ ErrorCode: "PROMO_CODE_DISABLED",
+ })
+ return
+ }
+
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 19356e46..01f39478 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -4,6 +4,7 @@ package dto
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -55,6 +56,7 @@ type SystemSettings struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 0fc61144..8723c746 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go
index 2aed1287..845d51e5 100644
--- a/backend/internal/pkg/tlsfingerprint/dialer_test.go
+++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go
@@ -305,3 +305,139 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
+
+// TestProfileExpectation defines expected fingerprint values for a profile.
+type TestProfileExpectation struct {
+ Profile *Profile
+ ExpectedJA3 string // Expected JA3 hash (empty = don't check)
+ ExpectedJA4 string // Expected full JA4 (empty = don't check)
+ JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
+}
+
+// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
+// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
+func TestAllProfiles(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ // Define all profiles to test with their expected fingerprints
+ // These profiles are from config.yaml gateway.tls_fingerprint.profiles
+ profiles := []TestProfileExpectation{
+ {
+ // Linux x64 Node.js v22.17.1
+ // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
+ // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
+ Profile: &Profile{
+ Name: "linux_x64_node_v22171",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part
+ },
+ {
+ // MacOS arm64 Node.js v22.18.0
+ // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
+ // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
+ Profile: &Profile{
+ Name: "macos_arm64_node_v22180",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
+ },
+ }
+
+ for _, tc := range profiles {
+ tc := tc // capture range variable
+ t.Run(tc.Profile.Name, func(t *testing.T) {
+ fp := fetchFingerprint(t, tc.Profile)
+ if fp == nil {
+ return // fetchFingerprint already called t.Fatal
+ }
+
+ t.Logf("Profile: %s", tc.Profile.Name)
+ t.Logf(" JA3: %s", fp.JA3)
+ t.Logf(" JA3 Hash: %s", fp.JA3Hash)
+ t.Logf(" JA4: %s", fp.JA4)
+ t.Logf(" PeetPrint: %s", fp.PeetPrint)
+ t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
+
+ // Verify expectations
+ if tc.ExpectedJA3 != "" {
+ if fp.JA3Hash == tc.ExpectedJA3 {
+ t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
+ } else {
+ t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
+ }
+ }
+
+ if tc.ExpectedJA4 != "" {
+ if fp.JA4 == tc.ExpectedJA4 {
+ t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
+ } else {
+ t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
+ }
+ }
+
+ // Check JA4 cipher hash (stable middle part)
+ // JA4 format: prefix_cipherHash_extHash
+ if tc.JA4CipherHash != "" {
+ if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
+ t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
+ } else {
+ t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
+ }
+ }
+ })
+ }
+}
+
+// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
+func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
+ t.Helper()
+
+ dialer := NewDialer(profile, nil)
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ return nil
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("failed to get fingerprint: %v", err)
+ return nil
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ return nil
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ return nil
+ }
+
+ return &fpResp.TLS
+}
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index c2673ad3..c11c079b 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
+// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
+ // schedulerCache 用于在账号状态变更时主动同步快照到缓存,
+ // 确保粘性会话能及时感知账号不可用状态。
+ // Used to proactively sync account snapshot to cache when status changes,
+ // ensuring sticky sessions can promptly detect unavailable accounts.
+ schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
-func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
- return newAccountRepositoryWithSQL(client, sqlDB)
+func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
+ return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
-func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
- return &accountRepository{client: client, sql: sqlq}
+func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
+ return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
+ if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
+ r.syncSchedulerAccountSnapshot(ctx, account.ID)
+ }
return nil
}
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
+ r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
+// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
+// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
+// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
+//
+// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
+// when account status changes. Called when account is set to error, disabled,
+// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
+// logic can promptly detect the latest account state and avoid using unavailable accounts.
+func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
+ if r == nil || r.schedulerCache == nil || accountID <= 0 {
+ return
+ }
+ account, err := r.GetByID(ctx, accountID)
+ if err != nil {
+ log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
+ return
+ }
+ if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
+ log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
+ }
+}
+
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
+ r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
+ if !schedulable {
+ r.syncSchedulerAccountSnapshot(ctx, id)
+ }
return nil
}
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
+ shouldSync := false
+ if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
+ shouldSync = true
+ }
+ if updates.Schedulable != nil && !*updates.Schedulable {
+ shouldSync = true
+ }
+ if shouldSync {
+ for _, id := range ids {
+ r.syncSchedulerAccountSnapshot(ctx, id)
+ }
+ }
}
return rows, nil
}
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index 250b141d..a054b6d6 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository
}
+type schedulerCacheRecorder struct {
+ setAccounts []*service.Account
+}
+
+func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
+ return nil, false, nil
+}
+
+func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
+ return nil
+}
+
+func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
+ return nil, nil
+}
+
+func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
+ s.setAccounts = append(s.setAccounts, account)
+ return nil
+}
+
+func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
+ return nil
+}
+
+func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
+ return nil
+}
+
+func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
+ return true, nil
+}
+
+func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
+ return nil, nil
+}
+
+func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
+ return 0, nil
+}
+
+func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
+ return nil
+}
+
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
- s.repo = newAccountRepositoryWithSQL(s.client, tx)
+ s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
+func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
+ cacheRecorder := &schedulerCacheRecorder{}
+ s.repo.schedulerCache = cacheRecorder
+
+ account.Status = service.StatusDisabled
+ err := s.repo.Update(s.ctx, account)
+ s.Require().NoError(err, "Update")
+
+ s.Require().Len(cacheRecorder.setAccounts, 1)
+ s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
+ s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
+}
+
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
- repo := newAccountRepositoryWithSQL(client, tx)
+ repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background()
tt.setup(client)
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
+ cacheRecorder := &schedulerCacheRecorder{}
+ s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
+ s.Require().Len(cacheRecorder.setAccounts, 1)
+ s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
+}
+
+func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
+ account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
+ account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
+ cacheRecorder := &schedulerCacheRecorder{}
+ s.repo.schedulerCache = cacheRecorder
+
+ disabled := service.StatusDisabled
+ rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
+ Status: &disabled,
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(int64(2), rows)
+
+ s.Require().Len(cacheRecorder.setAccounts, 2)
+ ids := map[int64]struct{}{}
+ for _, acc := range cacheRecorder.setAccounts {
+ ids[acc.ID] = struct{}{}
+ }
+ s.Require().Contains(ids, account1.ID)
+ s.Require().Contains(ids, account2.ID)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go
index 40a9ad05..58291b66 100644
--- a/backend/internal/repository/gateway_cache.go
+++ b/backend/internal/repository/gateway_cache.go
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
+
+// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
+// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
+// 以便下次请求能够重新选择可用账号。
+//
+// DeleteSessionAccountID removes the sticky session binding for the given session.
+// Called when the bound account becomes unavailable (e.g., error status, disabled,
+// or unschedulable), allowing subsequent requests to select a new available account.
+func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ key := buildSessionKey(groupID, sessionHash)
+ return c.rdb.Del(ctx, key).Err()
+}
diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go
index d8885bca..0eebc33f 100644
--- a/backend/internal/repository/gateway_cache_integration_test.go
+++ b/backend/internal/repository/gateway_cache_integration_test.go
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
+func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
+ sessionID := "openai:s4"
+ accountID := int64(102)
+ groupID := int64(1)
+ sessionTTL := 1 * time.Minute
+
+ require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
+ require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
+
+ _, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
+ require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
+}
+
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)
diff --git a/backend/internal/repository/gateway_routing_integration_test.go b/backend/internal/repository/gateway_routing_integration_test.go
index 5566d2e9..77591fe3 100644
--- a/backend/internal/repository/gateway_routing_integration_test.go
+++ b/backend/internal/repository/gateway_routing_integration_test.go
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
- s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
+ s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index 07d57410..b7f3606f 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
- client := createOpenAIReqClient(proxyURL)
+ client := createOpenAIReqClient(s.tokenURL, proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
- client := createOpenAIReqClient(proxyURL)
+ client := createOpenAIReqClient(s.tokenURL, proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
-func createOpenAIReqClient(proxyURL string) *req.Client {
+func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
+ forceHTTP2 := false
+ if parsedURL, err := url.Parse(tokenURL); err == nil {
+ forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
+ }
return getSharedReqClient(reqClientOptions{
- ProxyURL: proxyURL,
- Timeout: 60 * time.Second,
+ ProxyURL: proxyURL,
+ Timeout: 120 * time.Second,
+ ForceHTTP2: forceHTTP2,
})
}
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index 51142306..f9df08c8 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
+func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
+ client := NewOpenAIOAuthClient()
+ svc, ok := client.(*openaiOAuthService)
+ require.True(t, ok)
+ require.Equal(t, openai.TokenURL, svc.tokenURL)
+}
+
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}
diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go
index b23462a4..af71a7ee 100644
--- a/backend/internal/repository/req_client_pool.go
+++ b/backend/internal/repository/req_client_pool.go
@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL(支持 http/https/socks5)
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
+ ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
+ if opts.ForceHTTP2 {
+ client = client.EnableForceHTTP2()
+ }
if opts.Impersonate {
client = client.ImpersonateChrome()
}
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
- return fmt.Sprintf("%s|%s|%t",
+ return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
+ opts.ForceHTTP2,
)
}
diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go
new file mode 100644
index 00000000..cf7e8bd0
--- /dev/null
+++ b/backend/internal/repository/req_client_pool_test.go
@@ -0,0 +1,102 @@
+package repository
+
+import (
+ "reflect"
+ "sync"
+ "testing"
+ "time"
+ "unsafe"
+
+ "github.com/imroc/req/v3"
+ "github.com/stretchr/testify/require"
+)
+
+func forceHTTPVersion(t *testing.T, client *req.Client) string {
+ t.Helper()
+ transport := client.GetTransport()
+ field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
+ require.True(t, field.IsValid(), "forceHttpVersion field not found")
+ require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
+ return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
+}
+
+func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ base := reqClientOptions{
+ ProxyURL: "http://proxy.local:8080",
+ Timeout: time.Second,
+ }
+ clientDefault := getSharedReqClient(base)
+
+ force := base
+ force.ForceHTTP2 = true
+ clientForce := getSharedReqClient(force)
+
+ require.NotSame(t, clientDefault, clientForce)
+ require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
+}
+
+func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ opts := reqClientOptions{
+ ProxyURL: "http://proxy.local:8080",
+ Timeout: 2 * time.Second,
+ }
+ first := getSharedReqClient(opts)
+ second := getSharedReqClient(opts)
+ require.Same(t, first, second)
+}
+
+func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ opts := reqClientOptions{
+ ProxyURL: " http://proxy.local:8080 ",
+ Timeout: 3 * time.Second,
+ }
+ key := buildReqClientKey(opts)
+ sharedReqClients.Store(key, "invalid")
+
+ client := getSharedReqClient(opts)
+
+ require.NotNil(t, client)
+ loaded, ok := sharedReqClients.Load(key)
+ require.True(t, ok)
+ require.IsType(t, "invalid", loaded)
+}
+
+func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ opts := reqClientOptions{
+ ProxyURL: " http://proxy.local:8080 ",
+ Timeout: 4 * time.Second,
+ Impersonate: true,
+ }
+ client := getSharedReqClient(opts)
+
+ require.NotNil(t, client)
+ require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
+}
+
+func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
+ require.Equal(t, "2", forceHTTPVersion(t, client))
+}
+
+func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
+ require.Equal(t, "", forceHTTPVersion(t, client))
+}
+
+func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
+ require.Equal(t, 120*time.Second, client.GetClient().Timeout)
+}
+
+func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
+ sharedReqClients = sync.Map{}
+ client := createGeminiReqClient("http://proxy.local:8080")
+ require.Equal(t, "", forceHTTPVersion(t, client))
+}
diff --git a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
index e442a125..a88b74ef 100644
--- a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
+++ b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
- accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
+ accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 4ce58942..230a3c60 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -412,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -450,6 +451,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
+ "promo_code_enabled": true,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 386b43fc..854e7732 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
- // 应用优惠码(如果提供)
- if promoCode != "" && s.promoService != nil {
+ // 应用优惠码(如果提供且功能已启用)
+ if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index da1b9377..3bb63ffa 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -71,6 +71,7 @@ const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
+ SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index 4d17d5e1..26eb24e4 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -182,6 +182,7 @@ var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
type mockGatewayCacheForPlatform struct {
sessionBindings map[string]int64
+ deletedSessions map[string]int
}
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
@@ -203,6 +204,18 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
+func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if m.sessionBindings == nil {
+ return nil
+ }
+ if m.deletedSessions == nil {
+ m.deletedSessions = make(map[string]int)
+ }
+ m.deletedSessions[sessionHash]++
+ delete(m.sessionBindings, sessionHash)
+ return nil
+}
+
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
@@ -626,6 +639,363 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
})
}
+func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *testing.T) {
+ ctx := context.Background()
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ require.Equal(t, PlatformAntigravity, acc.Platform)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionClears(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(10)
+ requestedModel := "claude-3-5-sonnet-20241022"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusDisabled, Schedulable: true},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-123": 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-group",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {1, 2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ require.Equal(t, 1, cache.deletedSessions["session-123"])
+ require.Equal(t, int64(2), cache.sessionBindings["session-123"])
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionHit(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(11)
+ requestedModel := "claude-3-5-sonnet-20241022"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-456": 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-group-hit",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {1, 2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-456", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_RoutedFallbackToNormal(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(12)
+ requestedModel := "claude-3-5-sonnet-20241022"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-fallback",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {99},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_NoModelSupport(t *testing.T) {
+ ctx := context.Background()
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
+ },
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.Error(t, err)
+ require.Nil(t, acc)
+ require.Contains(t, err.Error(), "supporting model")
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *testing.T) {
+ ctx := context.Background()
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
+ {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(50)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-group": 1},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-group", "", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_StickyModelMismatchFallback(t *testing.T) {
+ ctx := context.Background()
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
+ },
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-miss": 1},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-miss", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_PreferNeverUsed(t *testing.T) {
+ ctx := context.Background()
+ lastUsed := time.Now().Add(-1 * time.Hour)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{},
+ accountsByID: map[int64]*Account{},
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
+ require.Error(t, err)
+ require.Nil(t, acc)
+ require.Contains(t, err.Error(), "no available accounts")
+}
+
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
@@ -743,6 +1113,301 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
})
+ t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) {
+ groupID := int64(30)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed-select",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ })
+
+ t.Run("混合调度-路由粘性命中", func(t *testing.T) {
+ groupID := int64(31)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-777": 2},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed-sticky",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-777", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ })
+
+ t.Run("混合调度-路由账号缺失回退", func(t *testing.T) {
+ groupID := int64(32)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed-miss",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {99},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+ })
+
+ t.Run("混合调度-路由账号未启用mixed_scheduling回退", func(t *testing.T) {
+ groupID := int64(33)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed-disabled",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+ })
+
+ t.Run("混合调度-路由过滤覆盖", func(t *testing.T) {
+ groupID := int64(35)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ resetAt := time.Now().Add(10 * time.Minute)
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
+ {ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
+ {
+ ID: 4,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Extra: map[string]any{
+ "model_rate_limits": map[string]any{
+ "claude_sonnet": map[string]any{
+ "rate_limit_reset_at": resetAt.Format(time.RFC3339),
+ },
+ },
+ },
+ },
+ {
+ ID: 5,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
+ },
+ {ID: 6, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ {ID: 7, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed-filter",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {1, 2, 3, 4, 5, 6, 7},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ excluded := map[int64]struct{}{1: {}}
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, excluded, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(7), acc.ID)
+ })
+
+ t.Run("混合调度-粘性命中分组账号", func(t *testing.T) {
+ groupID := int64(34)
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-group": 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-group", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+ })
+
t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -826,6 +1491,85 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户")
})
+ t.Run("混合调度-粘性会话不可调度-清理并回退", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-123": 1},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ require.Equal(t, 1, cache.deletedSessions["session-123"])
+ require.Equal(t, int64(2), cache.sessionBindings["session-123"])
+ })
+
+ t.Run("混合调度-路由粘性不可调度-清理并回退", func(t *testing.T) {
+ groupID := int64(12)
+ requestedModel := "claude-3-5-sonnet-20241022"
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"session-123": 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Name: "route-mixed",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ requestedModel: {1, 2},
+ },
+ },
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ groupRepo: groupRepo,
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ require.Equal(t, 1, cache.deletedSessions["session-123"])
+ require.Equal(t, int64(2), cache.sessionBindings["session-123"])
+ })
+
t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -876,6 +1620,65 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
})
+
+ t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
+ },
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.Error(t, err)
+ require.Nil(t, acc)
+ require.Contains(t, err.Error(), "supporting model")
+ })
+
+ t.Run("混合调度-优先未使用账号", func(t *testing.T) {
+ lastUsed := time.Now().Add(-2 * time.Hour)
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: testConfig(),
+ }
+
+ acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ })
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
@@ -962,10 +1765,20 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
+ acquireResults map[int64]bool
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ waitCounts map[int64]int
+ skipDefaultLoad bool
}
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
+ if m.acquireResults != nil {
+ if result, ok := m.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
return true, nil
}
@@ -986,6 +1799,11 @@ func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, ac
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if m.waitCounts != nil {
+ if count, ok := m.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
return 0, nil
}
@@ -1011,8 +1829,25 @@ func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID in
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
+ if m.loadBatchErr != nil {
+ return nil, m.loadBatchErr
+ }
result := make(map[int64]*AccountLoadInfo, len(accounts))
+ if m.skipDefaultLoad && m.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := m.loadMap[acc.ID]; ok {
+ result[acc.ID] = load
+ }
+ }
+ return result, nil
+ }
for _, acc := range accounts {
+ if m.loadMap != nil {
+ if load, ok := m.loadMap[acc.ID]; ok {
+ result[acc.ID] = load
+ continue
+ }
+ }
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
@@ -1251,6 +2086,48 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
})
+ t.Run("粘性账号禁用-清理会话并回退选择", func(t *testing.T) {
+ testCtx := context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAnthropic)
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+ repo.listPlatformFunc = func(ctx context.Context, platform string) ([]Account, error) {
+ return repo.accounts, nil
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"sticky": 1},
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID, "粘性账号禁用时应回退到可用账号")
+ updatedID, ok := cache.sessionBindings["sticky"]
+ require.True(t, ok, "粘性会话应更新绑定")
+ require.Equal(t, int64(2), updatedID, "粘性会话应绑定到新账号")
+ })
+
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
@@ -1340,6 +2217,751 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
+
+ t.Run("粘性账号槽位满-返回粘性等待计划", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{"sticky": 1},
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+ cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1
+
+ concurrencyCache := &mockConcurrencyCache{
+ acquireResults: map[int64]bool{1: false},
+ waitCounts: map[int64]int{1: 0},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.WaitPlan)
+ require.Equal(t, int64(1), result.Account.ID)
+ require.Equal(t, 0, concurrencyCache.loadBatchCalls)
+ })
+
+ t.Run("负载批量查询失败-降级旧顺序选择", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadBatchErr: errors.New("load batch failed"),
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID)
+ require.Equal(t, int64(2), cache.sessionBindings["legacy"])
+ })
+
+ t.Run("模型路由-粘性账号等待计划", func(t *testing.T) {
+ groupID := int64(20)
+ sessionHash := "route-sticky"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{sessionHash: 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+ cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1
+
+ concurrencyCache := &mockConcurrencyCache{
+ acquireResults: map[int64]bool{1: false},
+ waitCounts: map[int64]int{1: 0},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.WaitPlan)
+ require.Equal(t, int64(1), result.Account.ID)
+ })
+
+ t.Run("模型路由-粘性账号命中", func(t *testing.T) {
+ groupID := int64(20)
+ sessionHash := "route-hit"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{sessionHash: 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(1), result.Account.ID)
+ require.Equal(t, 0, concurrencyCache.loadBatchCalls)
+ })
+
+ t.Run("模型路由-粘性账号缺失-清理并回退", func(t *testing.T) {
+ groupID := int64(22)
+ sessionHash := "route-missing"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{sessionHash: 1},
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID)
+ require.Equal(t, 1, cache.deletedSessions[sessionHash])
+ require.Equal(t, int64(2), cache.sessionBindings[sessionHash])
+ })
+
+ t.Run("模型路由-按负载选择账号", func(t *testing.T) {
+ groupID := int64(21)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 80},
+ 2: {AccountID: 2, LoadRate: 20},
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID)
+ require.Equal(t, int64(2), cache.sessionBindings["route"])
+ })
+
+ t.Run("模型路由-路由账号全满返回等待计划", func(t *testing.T) {
+ groupID := int64(23)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ acquireResults: map[int64]bool{1: false, 2: false},
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 10},
+ 2: {AccountID: 2, LoadRate: 20},
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.WaitPlan)
+ require.Equal(t, int64(1), result.Account.ID)
+ })
+
+ t.Run("模型路由-路由账号全满-回退普通选择", func(t *testing.T) {
+ groupID := int64(22)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 3, Platform: PlatformAnthropic, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 100},
+ 2: {AccountID: 2, LoadRate: 100},
+ 3: {AccountID: 3, LoadRate: 0},
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(3), result.Account.ID)
+ require.Equal(t, int64(3), cache.sessionBindings["fallback"])
+ })
+
+ t.Run("负载批量失败且无法获取-兜底等待", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadBatchErr: errors.New("load batch failed"),
+ acquireResults: map[int64]bool{1: false, 2: false},
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.WaitPlan)
+ require.Equal(t, int64(1), result.Account.ID)
+ })
+
+ t.Run("Gemini负载排序-优先OAuth", func(t *testing.T) {
+ groupID := int64(24)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeAPIKey},
+ {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeOAuth},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformGemini,
+ Status: StatusActive,
+ Hydrated: true,
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 10},
+ 2: {AccountID: 2, LoadRate: 10},
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID)
+ })
+
+ t.Run("模型路由-过滤路径覆盖", func(t *testing.T) {
+ groupID := int64(70)
+ now := time.Now().Add(10 * time.Minute)
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 3, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5},
+ {ID: 4, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {
+ ID: 5,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Extra: map[string]any{
+ "model_rate_limits": map[string]any{
+ "claude_sonnet": map[string]any{
+ "rate_limit_reset_at": now.Format(time.RFC3339),
+ },
+ },
+ },
+ },
+ {
+ ID: 6,
+ Platform: PlatformAnthropic,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
+ },
+ {ID: 7, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-3-5-sonnet-20241022": {1, 2, 3, 4, 5, 6},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{}
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ excluded := map[int64]struct{}{1: {}}
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(7), result.Account.ID)
+ })
+
+ t.Run("ClaudeCode限制-回退分组", func(t *testing.T) {
+ groupID := int64(60)
+ fallbackID := int64(61)
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ClaudeCodeOnly: true,
+ FallbackGroupID: func() *int64 {
+ v := fallbackID
+ return &v
+ }(),
+ },
+ fallbackID: {
+ ID: fallbackID,
+ Platform: PlatformGemini,
+ Status: StatusActive,
+ Hydrated: true,
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: &mockGatewayCacheForPlatform{},
+ cfg: cfg,
+ concurrencyService: nil,
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(1), result.Account.ID)
+ })
+
+ t.Run("ClaudeCode限制-无降级返回错误", func(t *testing.T) {
+ groupID := int64(62)
+
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ClaudeCodeOnly: true,
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+
+ svc := &GatewayService{
+ accountRepo: &mockAccountRepoForPlatform{},
+ groupRepo: groupRepo,
+ cache: &mockGatewayCacheForPlatform{},
+ cfg: cfg,
+ concurrencyService: nil,
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "")
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.ErrorIs(t, err, ErrClaudeCodeOnly)
+ })
+
+ t.Run("负载可用但无法获取槽位-兜底等待", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ acquireResults: map[int64]bool{1: false, 2: false},
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 10},
+ 2: {AccountID: 2, LoadRate: 20},
+ },
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.WaitPlan)
+ require.Equal(t, int64(1), result.Account.ID)
+ })
+
+ t.Run("负载信息缺失-使用默认负载", func(t *testing.T) {
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForPlatform{}
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ concurrencyCache := &mockConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 50},
+ },
+ skipDefaultLoad: true,
+ }
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID)
+ })
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index f04397e8..36bb8e84 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -136,11 +136,24 @@ var allowedHeaders = map[string]bool{
"content-type": true,
}
-// GatewayCache defines cache operations for gateway service
+// GatewayCache 定义网关服务的缓存操作接口。
+// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
+//
+// GatewayCache defines cache operations for gateway service.
+// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
+ // GetSessionAccountID 获取粘性会话绑定的账号 ID
+ // Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
+ // SetSessionAccountID 设置粘性会话与账号的绑定关系
+ // Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
+ // RefreshSessionTTL 刷新粘性会话的过期时间
+ // Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
+ // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
+ // Delete sticky session binding, used to proactively clean up when account becomes unavailable
+ DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -151,6 +164,28 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
+// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
+// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
+// 这确保后续请求不会继续使用不可用的账号。
+//
+// shouldClearStickySession checks if an account is in an unschedulable state
+// and the sticky session binding should be cleared.
+// Returns true when account status is error/disabled, schedulable is false,
+// or within temporary unschedulable period.
+// This ensures subsequent requests won't continue using unavailable accounts.
+func shouldClearStickySession(account *Account) bool {
+ if account == nil {
+ return false
+ }
+ if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
+ return true
+ }
+ if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
+ return true
+ }
+ return false
+}
+
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
@@ -1067,6 +1102,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
+ } else {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
}
}
@@ -1173,41 +1210,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
- if ok && s.isAccountInGroup(account, groupID) &&
- s.isAccountAllowedForPlatform(account, platform, useMixed) &&
- account.IsSchedulableForModel(requestedModel) &&
- (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
- result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
- if err == nil && result.Acquired {
- // 会话数量限制检查
- if !s.checkAndRegisterSession(ctx, account, sessionHash) {
- result.ReleaseFunc() // 释放槽位,继续到 Layer 2
- } else {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
- return &AccountSelectionResult{
- Account: account,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
- }
+ if ok {
+ // 检查账户是否需要清理粘性会话绑定
+ // Check if the account needs sticky session cleanup
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
+ if !clearSticky && s.isAccountInGroup(account, groupID) &&
+ s.isAccountAllowedForPlatform(account, platform, useMixed) &&
+ account.IsSchedulableForModel(requestedModel) &&
+ (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
+ s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
+ result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
+ if err == nil && result.Acquired {
+ // 会话数量限制检查
+ // Session count limit check
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
+ result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+ } else {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
+ }
- waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
- if waitingCount < cfg.StickySessionMaxWaiting {
- // 会话数量限制检查(等待计划也需要占用会话配额)
- if !s.checkAndRegisterSession(ctx, account, sessionHash) {
- // 会话限制已满,继续到 Layer 2
- } else {
- return &AccountSelectionResult{
- Account: account,
- WaitPlan: &AccountWaitPlan{
- AccountID: accountID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
+ waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
+ if waitingCount < cfg.StickySessionMaxWaiting {
+ // 会话数量限制检查(等待计划也需要占用会话配额)
+ // Session count limit check (wait plan also requires session quota)
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
+ // 会话限制已满,继续到 Layer 2
+ // Session limit full, continue to Layer 2
+ } else {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
}
}
}
@@ -1827,14 +1875,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
- if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
- if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
- log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ if err == nil {
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if s.debugModelRoutingEnabled() {
- log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+ }
+ return account, nil
}
- return account, nil
}
}
}
@@ -1924,11 +1978,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
- if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
- if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
- log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ if err == nil {
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ }
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ return account, nil
}
- return account, nil
}
}
}
@@ -2028,15 +2088,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
- if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
- if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
- if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
- log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ if err == nil {
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ }
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+ }
+ return account, nil
}
- if s.debugModelRoutingEnabled() {
- log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
- }
- return account, nil
}
}
}
@@ -2127,12 +2193,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
- if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
- if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
- if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
- log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ if err == nil {
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ }
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ return account, nil
}
- return account, nil
}
}
}
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 75de90f2..7234540f 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
- // 优先检查 context 中的强制平台(/antigravity 路由)
- var platform string
- forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
- if hasForcePlatform && forcePlatform != "" {
- platform = forcePlatform
- } else if groupID != nil {
- // 根据分组 platform 决定查询哪种账号
- var group *Group
- if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
- group = ctxGroup
- } else {
- var err error
- group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
- if err != nil {
- return nil, fmt.Errorf("get group failed: %w", err)
- }
- }
- platform = group.Platform
- } else {
- // 无分组时只使用原生 gemini 平台
- platform = PlatformGemini
+ // 1. 确定目标平台和调度模式
+ // Determine target platform and scheduling mode
+ platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
+ if err != nil {
+ return nil, err
}
- // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
- // 注意:强制平台模式不走混合调度
- useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
-
cacheKey := "gemini:" + sessionHash
- if sessionHash != "" {
- accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
- if err == nil && accountID > 0 {
- if _, excluded := excludedIDs[accountID]; !excluded {
- account, err := s.getSchedulableAccount(ctx, accountID)
- // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
- if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
- valid := false
- if account.Platform == platform {
- valid = true
- } else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
- valid = true
- }
- if valid {
- usable := true
- if s.rateLimitService != nil && requestedModel != "" {
- ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
- if err != nil {
- log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
- }
- if !ok {
- usable = false
- }
- }
- if usable {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
- return account, nil
- }
- }
- }
- }
- }
+ // 2. 尝试粘性会话命中
+ // Try sticky session hit
+ if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
+ return account, nil
}
- // 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
+ // 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
+ // Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
- var selected *Account
- for i := range accounts {
- acc := &accounts[i]
- if _, excluded := excludedIDs[acc.ID]; excluded {
- continue
- }
- // 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
- // 非混合调度模式(antigravity 分组):不需要过滤
- if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
- continue
- }
- if !acc.IsSchedulableForModel(requestedModel) {
- continue
- }
- if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
- continue
- }
- if s.rateLimitService != nil && requestedModel != "" {
- ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
- if err != nil {
- log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
- }
- if !ok {
- continue
- }
- }
- if selected == nil {
- selected = acc
- continue
- }
- if acc.Priority < selected.Priority {
- selected = acc
- } else if acc.Priority == selected.Priority {
- switch {
- case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
- selected = acc
- case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
- // keep selected (never used is preferred)
- case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
- // Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
- if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
- selected = acc
- }
- default:
- if acc.LastUsedAt.Before(*selected.LastUsedAt) {
- selected = acc
- }
- }
- }
- }
+ // 4. 按优先级 + LRU 选择最佳账号
+ // Select best account by priority + LRU
+ selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return nil, errors.New("no available Gemini accounts")
}
+ // 5. 设置粘性会话绑定
+ // Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil
}
+// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
+// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
+//
+// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
+// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
+func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
+ // 优先检查 context 中的强制平台(/antigravity 路由)
+ forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
+ if hasForcePlatform && forcePlatform != "" {
+ return forcePlatform, false, true, nil
+ }
+
+ if groupID != nil {
+ // 根据分组 platform 决定查询哪种账号
+ var group *Group
+ if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
+ group = ctxGroup
+ } else {
+ group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
+ if err != nil {
+ return "", false, false, fmt.Errorf("get group failed: %w", err)
+ }
+ }
+ // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
+ return group.Platform, group.Platform == PlatformGemini, false, nil
+ }
+
+ // 无分组时只使用原生 gemini 平台
+ return PlatformGemini, true, false, nil
+}
+
+// tryStickySessionHit 尝试从粘性会话获取账号。
+// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
+//
+// tryStickySessionHit attempts to get account from sticky session.
+// Returns account if hit and usable; clears session and returns nil if account unavailable.
+func (s *GeminiMessagesCompatService) tryStickySessionHit(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash, cacheKey, requestedModel string,
+ excludedIDs map[int64]struct{},
+ platform string,
+ useMixedScheduling bool,
+) *Account {
+ if sessionHash == "" {
+ return nil
+ }
+
+ accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
+ if err != nil || accountID <= 0 {
+ return nil
+ }
+
+ if _, excluded := excludedIDs[accountID]; excluded {
+ return nil
+ }
+
+ account, err := s.getSchedulableAccount(ctx, accountID)
+ if err != nil {
+ return nil
+ }
+
+ // 检查账号是否需要清理粘性会话
+ // Check if sticky session should be cleared
+ if shouldClearStickySession(account) {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
+ return nil
+ }
+
+ // 验证账号是否可用于当前请求
+ // Verify account is usable for current request
+ if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
+ return nil
+ }
+
+ // 刷新会话 TTL 并返回账号
+ // Refresh session TTL and return account
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
+ return account
+}
+
+// isAccountUsableForRequest 检查账号是否可用于当前请求。
+// 验证:模型调度、模型支持、平台匹配、速率限制预检。
+//
+// isAccountUsableForRequest checks if account is usable for current request.
+// Validates: model scheduling, model support, platform matching, rate limit precheck.
+func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
+ ctx context.Context,
+ account *Account,
+ requestedModel, platform string,
+ useMixedScheduling bool,
+) bool {
+ // 检查模型调度能力
+ // Check model scheduling capability
+ if !account.IsSchedulableForModel(requestedModel) {
+ return false
+ }
+
+ // 检查模型支持
+ // Check model support
+ if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
+ return false
+ }
+
+ // 检查平台匹配
+ // Check platform matching
+ if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
+ return false
+ }
+
+ // 速率限制预检
+ // Rate limit precheck
+ if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
+ return false
+ }
+
+ return true
+}
+
+// isAccountValidForPlatform 检查账号是否匹配目标平台。
+// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
+//
+// isAccountValidForPlatform checks if account matches target platform.
+// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
+func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
+ if account.Platform == platform {
+ return true
+ }
+ if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
+ return true
+ }
+ return false
+}
+
+// passesRateLimitPreCheck 执行速率限制预检。
+// 返回 true 表示通过预检或无需预检。
+//
+// passesRateLimitPreCheck performs rate limit precheck.
+// Returns true if passed or precheck not required.
+func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
+ if s.rateLimitService == nil || requestedModel == "" {
+ return true
+ }
+ ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
+ if err != nil {
+ log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
+ }
+ return ok
+}
+
+// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
+// 返回 nil 表示无可用账号。
+//
+// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
+// Returns nil if no available account.
+func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
+ ctx context.Context,
+ accounts []Account,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ platform string,
+ useMixedScheduling bool,
+) *Account {
+ var selected *Account
+
+ for i := range accounts {
+ acc := &accounts[i]
+
+ // 跳过被排除的账号
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ continue
+ }
+
+ // 检查账号是否可用于当前请求
+ if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
+ continue
+ }
+
+ // 选择最佳账号
+ if selected == nil {
+ selected = acc
+ continue
+ }
+
+ if s.isBetterGeminiAccount(acc, selected) {
+ selected = acc
+ }
+ }
+
+ return selected
+}
+
+// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
+// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
+//
+// isBetterGeminiAccount checks if candidate is better than current.
+// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
+func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
+ // 优先级更高(数值更小)
+ if candidate.Priority < current.Priority {
+ return true
+ }
+ if candidate.Priority > current.Priority {
+ return false
+ }
+
+ // 同优先级,比较最后使用时间
+ switch {
+ case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
+ // candidate 从未使用,优先
+ return true
+ case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
+ // current 从未使用,保持
+ return false
+ case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
+ // 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
+ return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
+ default:
+ // 都使用过,选择最久未使用的
+ return candidate.LastUsedAt.Before(*current.LastUsedAt)
+ }
+}
+
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go
index 262a05d9..c63a020c 100644
--- a/backend/internal/service/gemini_multiplatform_test.go
+++ b/backend/internal/service/gemini_multiplatform_test.go
@@ -15,8 +15,10 @@ import (
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
- accounts []Account
- accountsByID map[int64]*Account
+ accounts []Account
+ accountsByID map[int64]*Account
+ listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
+ listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
+ if m.listByPlatformFunc != nil {
+ return m.listByPlatformFunc(ctx, platforms)
+ }
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
+ if m.listByGroupFunc != nil {
+ return m.listByGroupFunc(ctx, groupID, platforms)
+ }
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
+ deletedSessions map[string]int
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return nil
}
+func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if m.sessionBindings == nil {
+ return nil
+ }
+ if m.deletedSessions == nil {
+ m.deletedSessions = make(map[string]int)
+ }
+ m.deletedSessions[sessionHash]++
+ delete(m.sessionBindings, sessionHash)
+ return nil
+}
+
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
})
+
+ t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
+ {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{
+ sessionBindings: map[string]int64{"gemini:session-123": 1},
+ }
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+ require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
+ require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
+ })
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(9)
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
+
+ repo := &mockAccountRepoForGemini{
+ listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
+ return nil, nil
+ },
+ listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
+ return []Account{
+ {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
+ }, nil
+ },
+ accountsByID: map[int64]*Account{
+ 1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
+ },
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
+ ctx := context.Background()
+
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {
+ ID: 1,
+ Platform: PlatformGemini,
+ Priority: 1,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
+ },
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
+ require.Error(t, err)
+ require.Nil(t, acc)
+ require.Contains(t, err.Error(), "supporting model")
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
+ {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{
+ sessionBindings: map[string]int64{"gemini:session-999": 1},
+ }
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(1), acc.ID)
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
+ {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ excluded := map[int64]struct{}{1: {}}
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForGemini{
+ listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
+ return nil, errors.New("query failed")
+ },
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
+ require.Error(t, err)
+ require.Nil(t, acc)
+ require.Contains(t, err.Error(), "query accounts failed")
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
+ ctx := context.Background()
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
+ {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
+}
+
+func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
+ ctx := context.Background()
+ oldTime := time.Now().Add(-2 * time.Hour)
+ newTime := time.Now().Add(-1 * time.Hour)
+ repo := &mockAccountRepoForGemini{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
+ {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ cache := &mockGatewayCacheForGemini{}
+ groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
+
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index ff731be5..74bff747 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -180,67 +180,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
+// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
- // 1. Check sticky session
- if sessionHash != "" {
- accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
- if err == nil && accountID > 0 {
- if _, excluded := excludedIDs[accountID]; !excluded {
- account, err := s.getSchedulableAccount(ctx, accountID)
- if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
- // Refresh sticky session TTL
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
- return account, nil
- }
- }
- }
+ cacheKey := "openai:" + sessionHash
+
+ // 1. 尝试粘性会话命中
+ // Try sticky session hit
+ if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
+ return account, nil
}
- // 2. Get schedulable OpenAI accounts
+ // 2. 获取可调度的 OpenAI 账号
+ // Get schedulable OpenAI accounts
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
- // 3. Select by priority + LRU
- var selected *Account
- for i := range accounts {
- acc := &accounts[i]
- if _, excluded := excludedIDs[acc.ID]; excluded {
- continue
- }
- // Scheduler snapshots can be temporarily stale; re-check schedulability here to
- // avoid selecting accounts that were recently rate-limited/overloaded.
- if !acc.IsSchedulable() {
- continue
- }
- // Check model support
- if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
- continue
- }
- if selected == nil {
- selected = acc
- continue
- }
- // Lower priority value means higher priority
- if acc.Priority < selected.Priority {
- selected = acc
- } else if acc.Priority == selected.Priority {
- switch {
- case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
- selected = acc
- case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
- // keep selected (never used is preferred)
- case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
- // keep selected (both never used)
- default:
- // Same priority, select least recently used
- if acc.LastUsedAt.Before(*selected.LastUsedAt) {
- selected = acc
- }
- }
- }
- }
+ // 3. 按优先级 + LRU 选择最佳账号
+ // Select by priority + LRU
+ selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
@@ -249,14 +208,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return nil, errors.New("no available OpenAI accounts")
}
- // 4. Set sticky session
+ // 4. 设置粘性会话绑定
+ // Set sticky session binding
if sessionHash != "" {
- _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
+ _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
}
return selected, nil
}
+// tryStickySessionHit 尝试从粘性会话获取账号。
+// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
+//
+// tryStickySessionHit attempts to get account from sticky session.
+// Returns account if hit and usable; clears session and returns nil if account is unavailable.
+func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
+ if sessionHash == "" {
+ return nil
+ }
+
+ accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
+ if err != nil || accountID <= 0 {
+ return nil
+ }
+
+ if _, excluded := excludedIDs[accountID]; excluded {
+ return nil
+ }
+
+ account, err := s.getSchedulableAccount(ctx, accountID)
+ if err != nil {
+ return nil
+ }
+
+ // 检查账号是否需要清理粘性会话
+ // Check if sticky session should be cleared
+ if shouldClearStickySession(account) {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
+ return nil
+ }
+
+ // 验证账号是否可用于当前请求
+ // Verify account is usable for current request
+ if !account.IsSchedulable() || !account.IsOpenAI() {
+ return nil
+ }
+ if requestedModel != "" && !account.IsModelSupported(requestedModel) {
+ return nil
+ }
+
+ // 刷新会话 TTL 并返回账号
+ // Refresh session TTL and return account
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
+ return account
+}
+
+// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
+// 返回 nil 表示无可用账号。
+//
+// selectBestAccount selects the best account from candidates (priority + LRU).
+// Returns nil if no available account.
+func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
+ var selected *Account
+
+ for i := range accounts {
+ acc := &accounts[i]
+
+ // 跳过被排除的账号
+ // Skip excluded accounts
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ continue
+ }
+
+ // 调度器快照可能暂时过时,这里重新检查可调度性和平台
+ // Scheduler snapshots can be temporarily stale; re-check schedulability and platform
+ if !acc.IsSchedulable() || !acc.IsOpenAI() {
+ continue
+ }
+
+ // 检查模型支持
+ // Check model support
+ if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
+ continue
+ }
+
+ // 选择优先级最高且最久未使用的账号
+ // Select highest priority and least recently used
+ if selected == nil {
+ selected = acc
+ continue
+ }
+
+ if s.isBetterAccount(acc, selected) {
+ selected = acc
+ }
+ }
+
+ return selected
+}
+
+// isBetterAccount 判断 candidate 是否比 current 更优。
+// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
+//
+// isBetterAccount checks if candidate is better than current.
+// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
+func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
+ // 优先级更高(数值更小)
+ // Higher priority (lower value)
+ if candidate.Priority < current.Priority {
+ return true
+ }
+ if candidate.Priority > current.Priority {
+ return false
+ }
+
+ // 同优先级,比较最后使用时间
+ // Same priority, compare last used time
+ switch {
+ case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
+ // candidate 从未使用,优先
+ return true
+ case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
+ // current 从未使用,保持
+ return false
+ case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
+ // 都未使用,保持
+ return false
+ default:
+ // 都使用过,选择最久未使用的
+ return candidate.LastUsedAt.Before(*current.LastUsedAt)
+ }
+}
+
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
@@ -325,29 +408,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
- if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
- (requestedModel == "" || account.IsModelSupported(requestedModel)) {
- result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
- if err == nil && result.Acquired {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
- return &AccountSelectionResult{
- Account: account,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
+ if err == nil {
+ clearSticky := shouldClearStickySession(account)
+ if clearSticky {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
}
+ if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
+ (requestedModel == "" || account.IsModelSupported(requestedModel)) {
+ result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
+ if err == nil && result.Acquired {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
- waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
- if waitingCount < cfg.StickySessionMaxWaiting {
- return &AccountSelectionResult{
- Account: account,
- WaitPlan: &AccountWaitPlan{
- AccountID: accountID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
+ waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
+ if waitingCount < cfg.StickySessionMaxWaiting {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
}
}
}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 14dd7699..ae69a986 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -21,16 +21,42 @@ type stubOpenAIAccountRepo struct {
accounts []Account
}
+func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, errors.New("account not found")
+}
+
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
- return append([]Account(nil), r.accounts...), nil
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
- return append([]Account(nil), r.accounts...), nil
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ acquireResults map[int64]bool
+ waitCounts map[int64]int
+ skipDefaultLoad bool
}
type cancelReadCloser struct{}
@@ -53,6 +79,11 @@ func (w *failingGinWriter) Write(p []byte) (int, error) {
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ if c.acquireResults != nil {
+ if result, ok := c.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
return true, nil
}
@@ -61,8 +92,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
+ if c.loadBatchErr != nil {
+ return nil, c.loadBatchErr
+ }
out := make(map[int64]*AccountLoadInfo, len(accounts))
+ if c.skipDefaultLoad && c.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ }
+ }
+ return out, nil
+ }
for _, acc := range accounts {
+ if c.loadMap != nil {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ continue
+ }
+ }
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
@@ -111,6 +159,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
+func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if c.waitCounts != nil {
+ if count, ok := c.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
+ return 0, nil
+}
+
+type stubGatewayCache struct {
+ sessionBindings map[string]int64
+ deletedSessions map[string]int
+}
+
+func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
+ if id, ok := c.sessionBindings[sessionHash]; ok {
+ return id, nil
+ }
+ return 0, errors.New("not found")
+}
+
+func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if c.sessionBindings == nil {
+ c.sessionBindings = make(map[string]int64)
+ }
+ c.sessionBindings[sessionHash] = accountID
+ return nil
+}
+
+func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
+ return nil
+}
+
+func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if c.sessionBindings == nil {
+ return nil
+ }
+ if c.deletedSessions == nil {
+ c.deletedSessions = make(map[string]int)
+ }
+ c.deletedSessions[sessionHash]++
+ delete(c.sessionBindings, sessionHash)
+ return nil
+}
+
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
@@ -201,6 +294,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
+func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
+ sessionHash := "session-1"
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+ }
+ if acc == nil || acc.ID != 2 {
+ t.Fatalf("expected account 2, got %+v", acc)
+ }
+ if cache.deletedSessions["openai:"+sessionHash] != 1 {
+ t.Fatalf("expected sticky session to be deleted")
+ }
+ if cache.sessionBindings["openai:"+sessionHash] != 2 {
+ t.Fatalf("expected sticky session to bind to account 2")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
+ sessionHash := "session-2"
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+ t.Fatalf("expected account 2, got %+v", selection)
+ }
+ if cache.deletedSessions["openai:"+sessionHash] != 1 {
+ t.Fatalf("expected sticky session to be deleted")
+ }
+ if cache.sessionBindings["openai:"+sessionHash] != 2 {
+ t.Fatalf("expected sticky session to bind to account 2")
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {
+ ID: 1,
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
+ },
+ },
+ }
+ cache := &stubGatewayCache{}
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
+ if err == nil {
+ t.Fatalf("expected error for unsupported model")
+ }
+ if acc != nil {
+ t.Fatalf("expected nil account for unsupported model")
+ }
+ if !strings.Contains(err.Error(), "supporting model") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadBatchErr: errors.New("load batch failed"),
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.Account == nil {
+ t.Fatalf("expected selection")
+ }
+ if selection.Account.ID != 2 {
+ t.Fatalf("expected account 2, got %d", selection.Account.ID)
+ }
+ if cache.sessionBindings["openai:fallback"] != 2 {
+ t.Fatalf("expected sticky session updated")
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ acquireResults: map[int64]bool{1: false},
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 10},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.WaitPlan == nil {
+ t.Fatalf("expected wait plan fallback")
+ }
+ if selection.Account == nil || selection.Account.ID != 1 {
+ t.Fatalf("expected account 1")
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
+ sessionHash := "bind"
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+ }
+ if acc == nil || acc.ID != 1 {
+ t.Fatalf("expected account 1")
+ }
+ if cache.sessionBindings["openai:"+sessionHash] != 1 {
+ t.Fatalf("expected sticky session binding")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
+ sessionHash := "sticky-wait"
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
+ }
+ concurrencyCache := stubConcurrencyCache{
+ acquireResults: map[int64]bool{1: false},
+ waitCounts: map[int64]int{1: 0},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.WaitPlan == nil {
+ t.Fatalf("expected sticky wait plan")
+ }
+ if selection.Account == nil || selection.Account.ID != 1 {
+ t.Fatalf("expected account 1")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 80},
+ 2: {AccountID: 2, LoadRate: 10},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+ if cache.sessionBindings["openai:load"] != 2 {
+ t.Fatalf("expected sticky session updated")
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
+ sessionHash := "excluded"
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ excluded := map[int64]struct{}{1: {}}
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
+ if err != nil {
+ t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+ }
+ if acc == nil || acc.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
+ sessionHash := "non-openai"
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+ }
+ if acc == nil || acc.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
+ repo := stubOpenAIAccountRepo{accounts: []Account{}}
+ cache := &stubGatewayCache{}
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
+ if err == nil {
+ t.Fatalf("expected error for no accounts")
+ }
+ if acc != nil {
+ t.Fatalf("expected nil account")
+ }
+ if !strings.Contains(err.Error(), "no available OpenAI accounts") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
+ groupID := int64(1)
+ resetAt := time.Now().Add(1 * time.Hour)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{}
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err == nil {
+ t.Fatalf("expected error for no candidates")
+ }
+ if selection != nil {
+ t.Fatalf("expected nil selection")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 100},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.WaitPlan == nil {
+ t.Fatalf("expected wait plan")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadBatchErr: errors.New("load batch failed"),
+ acquireResults: map[int64]bool{1: false},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.WaitPlan == nil {
+ t.Fatalf("expected wait plan")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
+ groupID := int64(1)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 50},
+ },
+ skipDefaultLoad: true,
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+}
+
+func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
+ oldTime := time.Now().Add(-2 * time.Hour)
+ newTime := time.Now().Add(-1 * time.Hour)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
+ },
+ }
+ cache := &stubGatewayCache{}
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+ }
+ if acc == nil || acc.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+}
+
+func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
+ groupID := int64(1)
+ lastUsed := time.Now().Add(-1 * time.Hour)
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 10},
+ 2: {AccountID: 2, LoadRate: 10},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+ if err != nil {
+ t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+ }
+ if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+ t.Fatalf("expected account 2")
+ }
+}
+
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 5ab73588..d77dd30d 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
+ SettingKeyPromoCodeEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
@@ -88,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
+ PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
@@ -125,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
return &struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
@@ -140,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
@@ -162,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 注册设置
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
+ updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
@@ -248,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
return value == "true"
}
+// IsPromoCodeEnabled 检查是否启用优惠码功能
+func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
+ if err != nil {
+ return true // 默认启用
+ }
+ return value != "false"
+}
+
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
@@ -297,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
@@ -328,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
+ PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 05494272..919344e5 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -3,6 +3,7 @@ package service
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ PromoCodeEnabled bool
SMTPHost string
SMTPPort int
@@ -58,6 +59,7 @@ type SystemSettings struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ PromoCodeEnabled bool
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
new file mode 100644
index 00000000..4bd06b7b
--- /dev/null
+++ b/backend/internal/service/sticky_session_test.go
@@ -0,0 +1,54 @@
+//go:build unit
+
+// Package service 提供 API 网关核心服务。
+// 本文件包含 shouldClearStickySession 函数的单元测试,
+// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
+//
+// This file contains unit tests for the shouldClearStickySession function,
+// verifying correct sticky session clearing behavior under various account states.
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
+// 验证在以下情况下是否正确判断需要清理粘性会话:
+// - nil 账号:不清理(返回 false)
+// - 状态为错误或禁用:清理
+// - 不可调度:清理
+// - 临时不可调度且未过期:清理
+// - 临时不可调度已过期:不清理
+// - 正常可调度状态:不清理
+//
+// TestShouldClearStickySession tests the sticky session clearing logic.
+// Verifies correct behavior for various account states including:
+// nil account, error/disabled status, unschedulable, temporary unschedulable.
+func TestShouldClearStickySession(t *testing.T) {
+ now := time.Now()
+ future := now.Add(1 * time.Hour)
+ past := now.Add(-1 * time.Hour)
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {name: "nil account", account: nil, want: false},
+ {name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
+ {name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
+ {name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
+ {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
+ {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
+ {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, shouldClearStickySession(tt.account))
+ })
+ }
+}
diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go
index 05c423bc..c6c309b6 100644
--- a/backend/internal/service/usage_cleanup_service_test.go
+++ b/backend/internal/service/usage_cleanup_service_test.go
@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
+ require.Equal(t, 2, repo.deleteCalls[0].limit)
+ require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
+ require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
diff --git a/build_image.sh b/deploy/build_image.sh
similarity index 100%
rename from build_image.sh
rename to deploy/build_image.sh
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index c9a09e7d..6e2ade00 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -12,6 +12,7 @@ export interface SystemSettings {
// Registration settings
registration_enabled: boolean
email_verify_enabled: boolean
+ promo_code_enabled: boolean
// Default settings
default_balance: number
default_concurrency: number
@@ -64,6 +65,7 @@ export interface SystemSettings {
export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
+ promo_code_enabled?: boolean
default_balance?: number
default_concurrency?: number
site_name?: string
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d1eca6a1..120dac27 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2726,7 +2726,9 @@ export default {
enableRegistration: 'Enable Registration',
enableRegistrationHint: 'Allow new users to register',
emailVerification: 'Email Verification',
- emailVerificationHint: 'Require email verification for new registrations'
+ emailVerificationHint: 'Require email verification for new registrations',
+ promoCode: 'Promo Code',
+ promoCodeHint: 'Allow users to use promo codes during registration'
},
turnstile: {
title: 'Cloudflare Turnstile',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 86ac7ae5..4f7dcf64 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2879,7 +2879,9 @@ export default {
enableRegistration: '开放注册',
enableRegistrationHint: '允许新用户注册',
emailVerification: '邮箱验证',
- emailVerificationHint: '新用户注册时需要验证邮箱'
+ emailVerificationHint: '新用户注册时需要验证邮箱',
+ promoCode: '优惠码',
+ promoCodeHint: '允许用户在注册时使用优惠码'
},
turnstile: {
title: 'Cloudflare Turnstile',
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 7e3c71a0..9c4db599 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -312,6 +312,7 @@ export const useAppStore = defineStore('app', () => {
return {
registration_enabled: false,
email_verify_enabled: false,
+ promo_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: '',
site_name: siteName.value,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 1b7ae15d..37c9f030 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -70,6 +70,7 @@ export interface SendVerifyCodeResponse {
export interface PublicSettings {
registration_enabled: boolean
email_verify_enabled: boolean
+ promo_code_enabled: boolean
turnstile_enabled: boolean
turnstile_site_key: string
site_name: string
diff --git a/frontend/src/views/admin/RedeemView.vue b/frontend/src/views/admin/RedeemView.vue
index 50c55ba3..907c7541 100644
--- a/frontend/src/views/admin/RedeemView.vue
+++ b/frontend/src/views/admin/RedeemView.vue
@@ -238,7 +238,30 @@
v-model="generateForm.group_id"
:options="subscriptionGroupOptions"
:placeholder="t('admin.redeem.selectGroupPlaceholder')"
- />
+ >
+
+
+ {{ t('admin.settings.registration.promoCodeHint') }} +
+{{ t('admin.subscriptions.groupHint') }}