Merge pull request #343 from mt21625457/main

fix(调度): 完善粘性会话清理与账号调度刷新 和 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
This commit is contained in:
Wesley Liddick
2026-01-20 16:05:53 +08:00
committed by GitHub
20 changed files with 3401 additions and 261 deletions

View File

@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
return &accountRepository{client: client, sql: sqlq}
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
return nil
}
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
shouldSync = true
}
if updates.Schedulable != nil && !*updates.Schedulable {
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
}
}
return rows, nil
}

View File

@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository
}
type schedulerCacheRecorder struct {
setAccounts []*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
return nil, false, nil
}
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
return nil
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
return nil
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
return nil
}
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
return 0, nil
}
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
return nil
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background()
tt.setup(client)
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
}
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
disabled := service.StatusDisabled
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
Status: &disabled,
})
s.Require().NoError(err)
s.Require().Equal(int64(2), rows)
s.Require().Len(cacheRecorder.setAccounts, 2)
ids := map[int64]struct{}{}
for _, acc := range cacheRecorder.setAccounts {
ids[acc.ID] = struct{}{}
}
s.Require().Contains(ids, account1.ID)
s.Require().Contains(ids, account2.ID)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---

View File

@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)

View File

@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
})
}

View File

@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate {
client = client.ImpersonateChrome()
}
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
opts.ForceHTTP2,
)
}

View File

@@ -0,0 +1,102 @@
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)