Merge branch 'main' of github.com:Wei-Shaw/sub2api
This commit is contained in:
@@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
@@ -129,7 +130,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
|
||||
@@ -47,6 +47,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -90,6 +91,7 @@ type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -240,6 +242,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
@@ -314,6 +317,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
|
||||
@@ -195,6 +195,15 @@ type ValidatePromoCodeResponse struct {
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
// 检查优惠码功能是否启用
|
||||
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_DISABLED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
|
||||
@@ -4,6 +4,7 @@ package dto
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -55,6 +56,7 @@ type SystemSettings struct {
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
|
||||
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
|
||||
@@ -305,3 +305,139 @@ func mustParseURL(rawURL string) *url.URL {
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
|
||||
@@ -39,9 +39,15 @@ import (
|
||||
// 设计说明:
|
||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
|
||||
type accountRepository struct {
|
||||
client *dbent.Client // Ent ORM 客户端
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
|
||||
// 确保粘性会话能及时感知账号不可用状态。
|
||||
// Used to proactively sync account snapshot to cache when status changes,
|
||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB)
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
|
||||
}
|
||||
|
||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||
// 这种设计便于单元测试时注入 mock 对象。
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq}
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
|
||||
}
|
||||
|
||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
|
||||
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
|
||||
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
|
||||
//
|
||||
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
|
||||
// when account status changes. Called when account is set to error, disabled,
|
||||
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
|
||||
// logic can promptly detect the latest account state and avoid using unavailable accounts.
|
||||
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
|
||||
if r == nil || r.schedulerCache == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
shouldSync = true
|
||||
}
|
||||
if updates.Schedulable != nil && !*updates.Schedulable {
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
repo := newAccountRepositoryWithSQL(client, tx, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
disabled := service.StatusDisabled
|
||||
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||
Status: &disabled,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), rows)
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||
ids := map[int64]struct{}{}
|
||||
for _, acc := range cacheRecorder.setAccounts {
|
||||
ids[acc.ID] = struct{}{}
|
||||
}
|
||||
s.Require().Contains(ids, account1.ID)
|
||||
s.Require().Contains(ids, account2.ID)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||
// 以便下次请求能够重新选择可用账号。
|
||||
//
|
||||
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
|
||||
forceHTTP2 := false
|
||||
if parsedURL, err := url.Parse(tokenURL); err == nil {
|
||||
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
|
||||
}
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Timeout: 120 * time.Second,
|
||||
ForceHTTP2: forceHTTP2,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||
client := NewOpenAIOAuthClient()
|
||||
svc, ok := client.(*openaiOAuthService)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type reqClientOptions struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求超时时间
|
||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||
ForceHTTP2 bool // 是否强制使用 HTTP/2
|
||||
}
|
||||
|
||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
client := req.C().SetTimeout(opts.Timeout)
|
||||
if opts.ForceHTTP2 {
|
||||
client = client.EnableForceHTTP2()
|
||||
}
|
||||
if opts.Impersonate {
|
||||
client = client.ImpersonateChrome()
|
||||
}
|
||||
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
return fmt.Sprintf("%s|%s|%t|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
opts.ForceHTTP2,
|
||||
)
|
||||
}
|
||||
|
||||
102
backend/internal/repository/req_client_pool_test.go
Normal file
102
backend/internal/repository/req_client_pool_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func forceHTTPVersion(t *testing.T, client *req.Client) string {
|
||||
t.Helper()
|
||||
transport := client.GetTransport()
|
||||
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
|
||||
require.True(t, field.IsValid(), "forceHttpVersion field not found")
|
||||
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
|
||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
base := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: time.Second,
|
||||
}
|
||||
clientDefault := getSharedReqClient(base)
|
||||
|
||||
force := base
|
||||
force.ForceHTTP2 = true
|
||||
clientForce := getSharedReqClient(force)
|
||||
|
||||
require.NotSame(t, clientDefault, clientForce)
|
||||
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
first := getSharedReqClient(opts)
|
||||
second := getSharedReqClient(opts)
|
||||
require.Same(t, first, second)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
key := buildReqClientKey(opts)
|
||||
sharedReqClients.Store(key, "invalid")
|
||||
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
loaded, ok := sharedReqClients.Load(key)
|
||||
require.True(t, ok)
|
||||
require.IsType(t, "invalid", loaded)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 4 * time.Second,
|
||||
Impersonate: true,
|
||||
}
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||
require.Equal(t, "2", forceHTTPVersion(t, client))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
|
||||
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||
}
|
||||
|
||||
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createGeminiReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
|
||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||
cache := NewSchedulerCache(rdb)
|
||||
|
||||
|
||||
@@ -412,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -450,6 +451,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"promo_code_enabled": true,
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
|
||||
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
|
||||
@@ -71,6 +71,7 @@ const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -136,11 +136,24 @@ var allowedHeaders = map[string]bool{
|
||||
"content-type": true,
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
// GetSessionAccountID 获取粘性会话绑定的账号 ID
|
||||
// Get the account ID bound to a sticky session
|
||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||
// SetSessionAccountID 设置粘性会话与账号的绑定关系
|
||||
// Set the binding between sticky session and account
|
||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
// RefreshSessionTTL 刷新粘性会话的过期时间
|
||||
// Refresh the expiration time of a sticky session
|
||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -151,6 +164,28 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// or within temporary unschedulable period.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
func shouldClearStickySession(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
|
||||
return true
|
||||
}
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
AccountID int64
|
||||
MaxConcurrency int
|
||||
@@ -1067,6 +1102,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1173,7 +1210,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
// Check if the account needs sticky session cleanup
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
@@ -1181,6 +1225,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
@@ -1196,8 +1241,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@@ -1213,6 +1260,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
@@ -1827,7 +1875,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
@@ -1839,6 +1892,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
@@ -1924,7 +1978,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
@@ -1933,6 +1992,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
if !accountsLoaded {
|
||||
@@ -2028,7 +2088,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
@@ -2042,6 +2107,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
var err error
|
||||
@@ -2127,7 +2193,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
@@ -2138,6 +2209,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表
|
||||
if !accountsLoaded {
|
||||
|
||||
@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
// 1. 确定目标平台和调度模式
|
||||
// Determine target platform and scheduling mode
|
||||
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
platform = PlatformGemini
|
||||
}
|
||||
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
if account.Platform == platform {
|
||||
valid = true
|
||||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
valid = true
|
||||
}
|
||||
if valid {
|
||||
usable := true
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
usable = false
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
// 2. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// Query schedulable accounts (force platform mode: try group first, fallback to all)
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
||||
// 非混合调度模式(antigravity 分组):不需要过滤
|
||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 4. 按优先级 + LRU 选择最佳账号
|
||||
// Select best account by priority + LRU
|
||||
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
}
|
||||
|
||||
// 5. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
}
|
||||
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
|
||||
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
|
||||
//
|
||||
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
|
||||
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
|
||||
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, false, true, nil
|
||||
}
|
||||
|
||||
if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, false, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
return group.Platform, group.Platform == PlatformGemini, false, nil
|
||||
}
|
||||
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
return PlatformGemini, true, false, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account unavailable.
|
||||
func (s *GeminiMessagesCompatService) tryStickySessionHit(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash, cacheKey, requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// isAccountUsableForRequest 检查账号是否可用于当前请求。
|
||||
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
|
||||
//
|
||||
// isAccountUsableForRequest checks if account is usable for current request.
|
||||
// Validates: model scheduling, model support, platform matching, rate limit precheck.
|
||||
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
) bool {
|
||||
// 检查模型调度能力
|
||||
// Check model scheduling capability
|
||||
if !account.IsSchedulableForModel(requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查平台匹配
|
||||
// Check platform matching
|
||||
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 速率限制预检
|
||||
// Rate limit precheck
|
||||
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isAccountValidForPlatform 检查账号是否匹配目标平台。
|
||||
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
|
||||
//
|
||||
// isAccountValidForPlatform checks if account matches target platform.
|
||||
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
|
||||
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
|
||||
if account.Platform == platform {
|
||||
return true
|
||||
}
|
||||
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// passesRateLimitPreCheck 执行速率限制预检。
|
||||
// 返回 true 表示通过预检或无需预检。
|
||||
//
|
||||
// passesRateLimitPreCheck performs rate limit precheck.
|
||||
// Returns true if passed or precheck not required.
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" {
|
||||
return true
|
||||
}
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
|
||||
// Returns nil if no available account.
|
||||
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
ctx context.Context,
|
||||
accounts []Account,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查账号是否可用于当前请求
|
||||
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择最佳账号
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterGeminiAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||
//
|
||||
// isBetterGeminiAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
|
||||
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
|
||||
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
if m.listByPlatformFunc != nil {
|
||||
return m.listByPlatformFunc(ctx, platforms)
|
||||
}
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
if m.listByGroupFunc != nil {
|
||||
return m.listByGroupFunc(ctx, groupID, platforms)
|
||||
}
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if m.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if m.deletedSessions == nil {
|
||||
m.deletedSessions = make(map[string]int)
|
||||
}
|
||||
m.deletedSessions[sessionHash]++
|
||||
delete(m.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||
})
|
||||
|
||||
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
|
||||
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return nil, nil
|
||||
},
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
}, nil
|
||||
},
|
||||
accountsByID: map[int64]*Account{
|
||||
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-999": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return nil, errors.New("query failed")
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "query accounts failed")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
|
||||
@@ -180,67 +180,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
cacheKey := "openai:" + sessionHash
|
||||
|
||||
// 1. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
// 2. 获取可调度的 OpenAI 账号
|
||||
// Get schedulable OpenAI accounts
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -249,14 +208,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
// 4. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
||||
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
// Skip excluded accounts
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
|
||||
//
|
||||
// isBetterAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
|
||||
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
// Higher priority (lower value)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
// Same priority, compare last used time
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,保持
|
||||
return false
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
@@ -325,7 +408,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
}
|
||||
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
@@ -352,6 +440,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: Load-aware selection ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
|
||||
@@ -21,16 +21,42 @@ type stubOpenAIAccountRepo struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
return &r.accounts[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
acquireResults map[int64]bool
|
||||
waitCounts map[int64]int
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
type cancelReadCloser struct{}
|
||||
@@ -53,6 +79,11 @@ func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -61,8 +92,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if c.loadBatchErr != nil {
|
||||
return nil, c.loadBatchErr
|
||||
}
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
if c.skipDefaultLoad && c.loadMap != nil {
|
||||
for _, acc := range accounts {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
for _, acc := range accounts {
|
||||
if c.loadMap != nil {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
continue
|
||||
}
|
||||
}
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
@@ -111,6 +159,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type stubGatewayCache struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if c.sessionBindings == nil {
|
||||
c.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
c.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if c.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if c.deletedSessions == nil {
|
||||
c.deletedSessions = make(map[string]int)
|
||||
}
|
||||
c.deletedSessions[sessionHash]++
|
||||
delete(c.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
@@ -201,6 +294,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-1"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", acc)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-2"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", selection)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unsupported model")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account for unsupported model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "supporting model") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection")
|
||||
}
|
||||
if selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
||||
}
|
||||
if cache.sessionBindings["openai:fallback"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan fallback")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
||||
sessionHash := "bind"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session binding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
||||
sessionHash := "sticky-wait"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
waitCounts: map[int64]int{1: 0},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected sticky wait plan")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 80},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
if cache.sessionBindings["openai:load"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
||||
sessionHash := "excluded"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
||||
sessionHash := "non-openai"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no accounts")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
resetAt := time.Now().Add(1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no candidates")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 100},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 50},
|
||||
},
|
||||
skipDefaultLoad: true,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
lastUsed := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
keys := []string{
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
@@ -88,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
@@ -125,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
@@ -140,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
@@ -162,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 注册设置
|
||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
@@ -248,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsPromoCodeEnabled 检查是否启用优惠码功能
|
||||
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
|
||||
if err != nil {
|
||||
return true // 默认启用
|
||||
}
|
||||
return value != "false"
|
||||
}
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
@@ -297,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
@@ -328,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
@@ -58,6 +59,7 @@ type SystemSettings struct {
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
|
||||
54
backend/internal/service/sticky_session_test.go
Normal file
54
backend/internal/service/sticky_session_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
//go:build unit
|
||||
|
||||
// Package service 提供 API 网关核心服务。
|
||||
// 本文件包含 shouldClearStickySession 函数的单元测试,
|
||||
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
|
||||
//
|
||||
// This file contains unit tests for the shouldClearStickySession function,
|
||||
// verifying correct sticky session clearing behavior under various account states.
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
|
||||
// 验证在以下情况下是否正确判断需要清理粘性会话:
|
||||
// - nil 账号:不清理(返回 false)
|
||||
// - 状态为错误或禁用:清理
|
||||
// - 不可调度:清理
|
||||
// - 临时不可调度且未过期:清理
|
||||
// - 临时不可调度已过期:不清理
|
||||
// - 正常可调度状态:不清理
|
||||
//
|
||||
// TestShouldClearStickySession tests the sticky session clearing logic.
|
||||
// Verifies correct behavior for various account states including:
|
||||
// nil account, error/disabled status, unschedulable, temporary unschedulable.
|
||||
func TestShouldClearStickySession(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(1 * time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want bool
|
||||
}{
|
||||
{name: "nil account", account: nil, want: false},
|
||||
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
|
||||
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
|
||||
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
|
||||
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
|
||||
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
|
||||
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.deleteCalls, 3)
|
||||
require.Equal(t, 2, repo.deleteCalls[0].limit)
|
||||
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
|
||||
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
||||
|
||||
@@ -12,6 +12,7 @@ export interface SystemSettings {
|
||||
// Registration settings
|
||||
registration_enabled: boolean
|
||||
email_verify_enabled: boolean
|
||||
promo_code_enabled: boolean
|
||||
// Default settings
|
||||
default_balance: number
|
||||
default_concurrency: number
|
||||
@@ -64,6 +65,7 @@ export interface SystemSettings {
|
||||
export interface UpdateSettingsRequest {
|
||||
registration_enabled?: boolean
|
||||
email_verify_enabled?: boolean
|
||||
promo_code_enabled?: boolean
|
||||
default_balance?: number
|
||||
default_concurrency?: number
|
||||
site_name?: string
|
||||
|
||||
@@ -2726,7 +2726,9 @@ export default {
|
||||
enableRegistration: 'Enable Registration',
|
||||
enableRegistrationHint: 'Allow new users to register',
|
||||
emailVerification: 'Email Verification',
|
||||
emailVerificationHint: 'Require email verification for new registrations'
|
||||
emailVerificationHint: 'Require email verification for new registrations',
|
||||
promoCode: 'Promo Code',
|
||||
promoCodeHint: 'Allow users to use promo codes during registration'
|
||||
},
|
||||
turnstile: {
|
||||
title: 'Cloudflare Turnstile',
|
||||
|
||||
@@ -2879,7 +2879,9 @@ export default {
|
||||
enableRegistration: '开放注册',
|
||||
enableRegistrationHint: '允许新用户注册',
|
||||
emailVerification: '邮箱验证',
|
||||
emailVerificationHint: '新用户注册时需要验证邮箱'
|
||||
emailVerificationHint: '新用户注册时需要验证邮箱',
|
||||
promoCode: '优惠码',
|
||||
promoCodeHint: '允许用户在注册时使用优惠码'
|
||||
},
|
||||
turnstile: {
|
||||
title: 'Cloudflare Turnstile',
|
||||
|
||||
@@ -312,6 +312,7 @@ export const useAppStore = defineStore('app', () => {
|
||||
return {
|
||||
registration_enabled: false,
|
||||
email_verify_enabled: false,
|
||||
promo_code_enabled: true,
|
||||
turnstile_enabled: false,
|
||||
turnstile_site_key: '',
|
||||
site_name: siteName.value,
|
||||
|
||||
@@ -70,6 +70,7 @@ export interface SendVerifyCodeResponse {
|
||||
export interface PublicSettings {
|
||||
registration_enabled: boolean
|
||||
email_verify_enabled: boolean
|
||||
promo_code_enabled: boolean
|
||||
turnstile_enabled: boolean
|
||||
turnstile_site_key: string
|
||||
site_name: string
|
||||
|
||||
@@ -238,7 +238,30 @@
|
||||
v-model="generateForm.group_id"
|
||||
:options="subscriptionGroupOptions"
|
||||
:placeholder="t('admin.redeem.selectGroupPlaceholder')"
|
||||
>
|
||||
<template #selected="{ option }">
|
||||
<GroupBadge
|
||||
v-if="option"
|
||||
:name="(option as unknown as GroupOption).label"
|
||||
:platform="(option as unknown as GroupOption).platform"
|
||||
:subscription-type="(option as unknown as GroupOption).subscriptionType"
|
||||
:rate-multiplier="(option as unknown as GroupOption).rate"
|
||||
/>
|
||||
<span v-else class="text-gray-400">{{
|
||||
t('admin.redeem.selectGroupPlaceholder')
|
||||
}}</span>
|
||||
</template>
|
||||
<template #option="{ option, selected }">
|
||||
<GroupOptionItem
|
||||
:name="(option as unknown as GroupOption).label"
|
||||
:platform="(option as unknown as GroupOption).platform"
|
||||
:subscription-type="(option as unknown as GroupOption).subscriptionType"
|
||||
:rate-multiplier="(option as unknown as GroupOption).rate"
|
||||
:description="(option as unknown as GroupOption).description"
|
||||
:selected="selected"
|
||||
/>
|
||||
</template>
|
||||
</Select>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.redeem.validityDays') }}</label>
|
||||
@@ -370,7 +393,7 @@ import { useAppStore } from '@/stores/app'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
import type { RedeemCode, RedeemCodeType, Group } from '@/types'
|
||||
import type { RedeemCode, RedeemCodeType, Group, GroupPlatform, SubscriptionType } from '@/types'
|
||||
import type { Column } from '@/components/common/types'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||
@@ -378,12 +401,23 @@ import DataTable from '@/components/common/DataTable.vue'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import GroupBadge from '@/components/common/GroupBadge.vue'
|
||||
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
const { copyToClipboard: clipboardCopy } = useClipboard()
|
||||
|
||||
interface GroupOption {
|
||||
value: number
|
||||
label: string
|
||||
description: string | null
|
||||
platform: GroupPlatform
|
||||
subscriptionType: SubscriptionType
|
||||
rate: number
|
||||
}
|
||||
|
||||
const showGenerateDialog = ref(false)
|
||||
const showResultDialog = ref(false)
|
||||
const generatedCodes = ref<RedeemCode[]>([])
|
||||
@@ -395,7 +429,11 @@ const subscriptionGroupOptions = computed(() => {
|
||||
.filter((g) => g.subscription_type === 'subscription')
|
||||
.map((g) => ({
|
||||
value: g.id,
|
||||
label: g.name
|
||||
label: g.name,
|
||||
description: g.description,
|
||||
platform: g.platform,
|
||||
subscriptionType: g.subscription_type,
|
||||
rate: g.rate_multiplier
|
||||
}))
|
||||
})
|
||||
|
||||
|
||||
@@ -323,6 +323,21 @@
|
||||
</div>
|
||||
<Toggle v-model="form.email_verify_enabled" />
|
||||
</div>
|
||||
|
||||
<!-- Promo Code -->
|
||||
<div
|
||||
class="flex items-center justify-between border-t border-gray-100 pt-4 dark:border-dark-700"
|
||||
>
|
||||
<div>
|
||||
<label class="font-medium text-gray-900 dark:text-white">{{
|
||||
t('admin.settings.registration.promoCode')
|
||||
}}</label>
|
||||
<p class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.settings.registration.promoCodeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle v-model="form.promo_code_enabled" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1013,6 +1028,7 @@ type SettingsForm = SystemSettings & {
|
||||
const form = reactive<SettingsForm>({
|
||||
registration_enabled: true,
|
||||
email_verify_enabled: false,
|
||||
promo_code_enabled: true,
|
||||
default_balance: 0,
|
||||
default_concurrency: 1,
|
||||
site_name: 'Sub2API',
|
||||
@@ -1135,6 +1151,7 @@ async function saveSettings() {
|
||||
const payload: UpdateSettingsRequest = {
|
||||
registration_enabled: form.registration_enabled,
|
||||
email_verify_enabled: form.email_verify_enabled,
|
||||
promo_code_enabled: form.promo_code_enabled,
|
||||
default_balance: form.default_balance,
|
||||
default_concurrency: form.default_concurrency,
|
||||
site_name: form.site_name,
|
||||
|
||||
@@ -466,7 +466,28 @@
|
||||
v-model="assignForm.group_id"
|
||||
:options="subscriptionGroupOptions"
|
||||
:placeholder="t('admin.subscriptions.selectGroup')"
|
||||
>
|
||||
<template #selected="{ option }">
|
||||
<GroupBadge
|
||||
v-if="option"
|
||||
:name="(option as unknown as GroupOption).label"
|
||||
:platform="(option as unknown as GroupOption).platform"
|
||||
:subscription-type="(option as unknown as GroupOption).subscriptionType"
|
||||
:rate-multiplier="(option as unknown as GroupOption).rate"
|
||||
/>
|
||||
<span v-else class="text-gray-400">{{ t('admin.subscriptions.selectGroup') }}</span>
|
||||
</template>
|
||||
<template #option="{ option, selected }">
|
||||
<GroupOptionItem
|
||||
:name="(option as unknown as GroupOption).label"
|
||||
:platform="(option as unknown as GroupOption).platform"
|
||||
:subscription-type="(option as unknown as GroupOption).subscriptionType"
|
||||
:rate-multiplier="(option as unknown as GroupOption).rate"
|
||||
:description="(option as unknown as GroupOption).description"
|
||||
:selected="selected"
|
||||
/>
|
||||
</template>
|
||||
</Select>
|
||||
<p class="input-hint">{{ t('admin.subscriptions.groupHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
@@ -599,7 +620,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { UserSubscription, Group } from '@/types'
|
||||
import type { UserSubscription, Group, GroupPlatform, SubscriptionType } from '@/types'
|
||||
import type { SimpleUser } from '@/api/admin/usage'
|
||||
import type { Column } from '@/components/common/types'
|
||||
import { formatDateOnly } from '@/utils/format'
|
||||
@@ -612,11 +633,21 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import EmptyState from '@/components/common/EmptyState.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import GroupBadge from '@/components/common/GroupBadge.vue'
|
||||
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
interface GroupOption {
|
||||
value: number
|
||||
label: string
|
||||
description: string | null
|
||||
platform: GroupPlatform
|
||||
subscriptionType: SubscriptionType
|
||||
rate: number
|
||||
}
|
||||
|
||||
// User column display mode: 'email' or 'username'
|
||||
const userColumnMode = ref<'email' | 'username'>('email')
|
||||
const USER_COLUMN_MODE_KEY = 'subscription-user-column-mode'
|
||||
@@ -792,7 +823,14 @@ const groupOptions = computed(() => [
|
||||
const subscriptionGroupOptions = computed(() =>
|
||||
groups.value
|
||||
.filter((g) => g.subscription_type === 'subscription' && g.status === 'active')
|
||||
.map((g) => ({ value: g.id, label: g.name }))
|
||||
.map((g) => ({
|
||||
value: g.id,
|
||||
label: g.name,
|
||||
description: g.description,
|
||||
platform: g.platform,
|
||||
subscriptionType: g.subscription_type,
|
||||
rate: g.rate_multiplier
|
||||
}))
|
||||
)
|
||||
|
||||
const applyFilters = () => {
|
||||
|
||||
@@ -96,7 +96,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Promo Code Input (Optional) -->
|
||||
<div>
|
||||
<div v-if="promoCodeEnabled">
|
||||
<label for="promo_code" class="input-label">
|
||||
{{ t('auth.promoCodeLabel') }}
|
||||
<span class="ml-1 text-xs font-normal text-gray-400 dark:text-dark-500">({{ t('common.optional') }})</span>
|
||||
@@ -260,6 +260,7 @@ const showPassword = ref<boolean>(false)
|
||||
// Public settings
|
||||
const registrationEnabled = ref<boolean>(true)
|
||||
const emailVerifyEnabled = ref<boolean>(false)
|
||||
const promoCodeEnabled = ref<boolean>(true)
|
||||
const turnstileEnabled = ref<boolean>(false)
|
||||
const turnstileSiteKey = ref<string>('')
|
||||
const siteName = ref<string>('Sub2API')
|
||||
@@ -294,22 +295,25 @@ const errors = reactive({
|
||||
// ==================== Lifecycle ====================
|
||||
|
||||
onMounted(async () => {
|
||||
// Read promo code from URL parameter
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
registrationEnabled.value = settings.registration_enabled
|
||||
emailVerifyEnabled.value = settings.email_verify_enabled
|
||||
promoCodeEnabled.value = settings.promo_code_enabled
|
||||
turnstileEnabled.value = settings.turnstile_enabled
|
||||
turnstileSiteKey.value = settings.turnstile_site_key || ''
|
||||
siteName.value = settings.site_name || 'Sub2API'
|
||||
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
|
||||
|
||||
// Read promo code from URL parameter only if promo code is enabled
|
||||
if (promoCodeEnabled.value) {
|
||||
const promoParam = route.query.promo as string
|
||||
if (promoParam) {
|
||||
formData.promo_code = promoParam
|
||||
// Validate the promo code from URL
|
||||
await validatePromoCodeDebounced(promoParam)
|
||||
}
|
||||
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
registrationEnabled.value = settings.registration_enabled
|
||||
emailVerifyEnabled.value = settings.email_verify_enabled
|
||||
turnstileEnabled.value = settings.turnstile_enabled
|
||||
turnstileSiteKey.value = settings.turnstile_site_key || ''
|
||||
siteName.value = settings.site_name || 'Sub2API'
|
||||
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings:', error)
|
||||
} finally {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { defineConfig, Plugin } from 'vite'
|
||||
import { defineConfig, loadEnv, Plugin } from 'vite'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import checker from 'vite-plugin-checker'
|
||||
import { resolve } from 'path'
|
||||
@@ -7,9 +7,7 @@ import { resolve } from 'path'
|
||||
* Vite 插件:开发模式下注入公开配置到 index.html
|
||||
* 与生产模式的后端注入行为保持一致,消除闪烁
|
||||
*/
|
||||
function injectPublicSettings(): Plugin {
|
||||
const backendUrl = process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080'
|
||||
|
||||
function injectPublicSettings(backendUrl: string): Plugin {
|
||||
return {
|
||||
name: 'inject-public-settings',
|
||||
transformIndexHtml: {
|
||||
@@ -35,14 +33,20 @@ function injectPublicSettings(): Plugin {
|
||||
}
|
||||
}
|
||||
|
||||
export default defineConfig({
|
||||
export default defineConfig(({ mode }) => {
|
||||
// 加载环境变量
|
||||
const env = loadEnv(mode, process.cwd(), '')
|
||||
const backendUrl = env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080'
|
||||
const devPort = Number(env.VITE_DEV_PORT || 3000)
|
||||
|
||||
return {
|
||||
plugins: [
|
||||
vue(),
|
||||
checker({
|
||||
typescript: true,
|
||||
vueTsc: true
|
||||
}),
|
||||
injectPublicSettings()
|
||||
injectPublicSettings(backendUrl)
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
@@ -104,16 +108,17 @@ export default defineConfig({
|
||||
},
|
||||
server: {
|
||||
host: '0.0.0.0',
|
||||
port: Number(process.env.VITE_DEV_PORT || 3000),
|
||||
port: devPort,
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080',
|
||||
target: backendUrl,
|
||||
changeOrigin: true
|
||||
},
|
||||
'/setup': {
|
||||
target: process.env.VITE_DEV_PROXY_TARGET || 'http://localhost:8080',
|
||||
target: backendUrl,
|
||||
changeOrigin: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user