Merge branch 'main' into feat/api-key-ip-restriction

This commit is contained in:
Edric Li
2026-01-10 18:49:50 +08:00
108 changed files with 18223 additions and 144 deletions

View File

@@ -49,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)

View File

@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call")
}
@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call")
}

View File

@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
visited := map[int64]struct{}{}
nextID := fallbackGroupID
for {
if _, seen := visited[nextID]; seen {
return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
return nil
// 降级分组不能启用 claude_code_only否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {

View File

@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call")
}
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call")
}

View File

@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}

View File

@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
// 长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token
if s.tokenProvider == nil {
@@ -603,7 +605,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
// 最后一次尝试也失败
resp = &http.Response{
@@ -696,7 +698,7 @@ urlFallbackLoop:
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 解析请求以获取 image_size用于图片计费
imageSize := s.extractImageSize(body)
@@ -1146,7 +1149,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
}
resp = &http.Response{
StatusCode: resp.StatusCode,
@@ -1200,7 +1203,7 @@ urlFallbackLoop:
goto handleSuccess
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
resetAt := ParseGeminiRateLimitResetTime(body)
@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
defaultDur = 5 * time.Minute
}
ra := time.Now().Add(defaultDur)
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return
}
resetTime := time.Unix(*resetAt, 0)
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return
}
// 其他错误码继续使用 rateLimitService

View File

@@ -0,0 +1,88 @@
package service
import (
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}

View File

@@ -52,6 +52,7 @@ type AuthService struct {
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
}
// NewAuthService 创建认证服务实例
@@ -62,6 +63,7 @@ func NewAuthService(
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
) *AuthService {
return &AuthService{
userRepo: userRepo,
@@ -70,16 +72,17 @@ func NewAuthService(
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
}
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
return s.RegisterWithVerification(ctx, email, password, "", "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证和优惠码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
} else {
// 重新获取用户信息以获取更新后的余额
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
user = updatedUser
}
}
}
// 生成token
token, err := s.GenerateToken(user)
if err != nil {

View File

@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService,
nil,
nil,
nil, // promoService
)
}
@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}

View File

@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription = "subscription"
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
accounts []Account
accountsByID map[int64]*Account
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
getByIDCalls int
}
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
m.getByIDCalls++
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
@@ -136,6 +139,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -148,6 +154,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -185,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func ptr[T any](v T) *T {
return &v
}
@@ -891,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
return m.accountWaitCounts[accountID], nil
}
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
}
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
return true, nil
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
result := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
@@ -989,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
})
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
@@ -1013,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Contains(t, err.Error(), "no available accounts")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
ctxGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
groupID := int64(42)
invalidGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
hydratedGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
svc := &GatewayService{}
ctx = svc.withGroupContext(ctx, hydratedGroup)
got, ok := ctx.Value(ctxkey.Group).(*Group)
require.True(t, ok)
require.Same(t, hydratedGroup, got)
}
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{fallbackID: fallbackGroup},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &groupID,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: group,
fallbackID: fallbackGroup,
},
}
svc := &GatewayService{
groupRepo: groupRepo,
}
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
require.Error(t, err)
require.Nil(t, gotGroup)
require.Nil(t, gotID)
require.Contains(t, err.Error(), "fallback group cycle")
}

View File

@@ -33,7 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
defaultMaxLineSize = 40 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
return nil, err
}
groupID = resolvedGroupID
ctx = s.withGroupContext(ctx, group)
platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
ctx = s.withGroupContext(ctx, group)
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
if err != nil {
return nil, err
}
@@ -478,10 +465,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountInGroup(account, groupID) &&
// 粘性命中仅在当前可调度候选集中生效。
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
@@ -519,6 +511,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -652,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
if !IsGroupContextValid(group) {
return ctx
}
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
return ctx
}
return context.WithValue(ctx, ctxkey.Group, group)
}
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
return group
}
return nil
}
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
if group := s.groupFromContext(ctx, groupID); group != nil {
return group, nil
}
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
return group, nil
}
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, nil, nil
}
currentID := *groupID
visited := map[int64]struct{}{}
for {
if _, seen := visited[currentID]; seen {
return nil, nil, fmt.Errorf("fallback group cycle detected")
}
visited[currentID] = struct{}{}
group, err := s.resolveGroupByID(ctx, currentID)
if err != nil {
return nil, nil, err
}
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
return group, &currentID, nil
}
if group.FallbackGroupID == nil {
return nil, nil, ErrClaudeCodeOnly
}
currentID = *group.FallbackGroupID
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return groupID, nil
return nil, groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil
return nil, groupID, nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
return nil, nil, err
}
if !group.ClaudeCodeOnly {
return groupID, nil
}
// 分组启用了 Claude Code 限制
if IsClaudeCodeClient(ctx) {
return groupID, nil
}
// 非 Claude Code 客户端,检查降级分组
if group.FallbackGroupID != nil {
return group.FallbackGroupID, nil
}
return nil, ErrClaudeCodeOnly
return group, resolvedID, nil
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if group != nil {
return group.Platform, false, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
return "", false, err
}
return group.Platform, false, nil
}
@@ -812,7 +853,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
@@ -844,6 +885,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -901,7 +945,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
@@ -936,6 +980,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}

View File

@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {
@@ -114,7 +120,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
@@ -172,6 +178,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -121,6 +122,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -131,6 +135,9 @@ func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, i
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -146,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type mockGroupRepoForGemini struct {
groups map[int64]*Group
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, errors.New("group not found")
}
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
@@ -242,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
}
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(7)
group := &Group{
ID: groupID,
Platform: PlatformGemini,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
ctx := context.Background()
groupID := int64(7)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{
groups: map[int64]*Group{
groupID: {ID: groupID, Platform: PlatformGemini},
},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
ctx := context.Background()

View File

@@ -10,6 +10,7 @@ type Group struct {
RateMultiplier float64
IsExclusive bool
Status string
Hydrated bool // indicates the group was loaded from a trusted repository source
SubscriptionType string
DailyLimitUSD *float64
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return g.ImagePrice2K
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
if group == nil {
return false
}
if group.ID <= 0 {
return false
}
if !group.Hydrated {
return false
}
if group.Platform == "" || group.Status == "" {
return false
}
return true
}

View File

@@ -16,6 +16,7 @@ var (
type GroupRepository interface {
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
GetByIDLite(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)

View File

@@ -0,0 +1,73 @@
package service
import (
"time"
)
// PromoCode 注册优惠码
type PromoCode struct {
ID int64
Code string
BonusAmount float64
MaxUses int
UsedCount int
Status string
ExpiresAt *time.Time
Notes string
CreatedAt time.Time
UpdatedAt time.Time
// 关联
UsageRecords []PromoCodeUsage
}
// PromoCodeUsage 优惠码使用记录
type PromoCodeUsage struct {
ID int64
PromoCodeID int64
UserID int64
BonusAmount float64
UsedAt time.Time
// 关联
PromoCode *PromoCode
User *User
}
// CanUse 检查优惠码是否可用
func (p *PromoCode) CanUse() bool {
if p.Status != PromoCodeStatusActive {
return false
}
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
return false
}
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
return false
}
return true
}
// IsExpired 检查是否已过期
func (p *PromoCode) IsExpired() bool {
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
}
// CreatePromoCodeInput 创建优惠码输入
type CreatePromoCodeInput struct {
Code string
BonusAmount float64
MaxUses int
ExpiresAt *time.Time
Notes string
}
// UpdatePromoCodeInput 更新优惠码输入
type UpdatePromoCodeInput struct {
Code *string
BonusAmount *float64
MaxUses *int
Status *string
ExpiresAt *time.Time
Notes *string
}

View File

@@ -0,0 +1,30 @@
package service
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// PromoCodeRepository 优惠码仓储接口
type PromoCodeRepository interface {
// 基础 CRUD
Create(ctx context.Context, code *PromoCode) error
GetByID(ctx context.Context, id int64) (*PromoCode, error)
GetByCode(ctx context.Context, code string) (*PromoCode, error)
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
Update(ctx context.Context, code *PromoCode) error
Delete(ctx context.Context, id int64) error
// 列表查询
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
// 使用记录
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
// 计数操作
IncrementUsedCount(ctx context.Context, id int64) error
}

View File

@@ -0,0 +1,256 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
)
// PromoService 优惠码服务
type PromoService struct {
promoRepo PromoCodeRepository
userRepo UserRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
}
// NewPromoService 创建优惠码服务实例
func NewPromoService(
promoRepo PromoCodeRepository,
userRepo UserRepository,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
) *PromoService {
return &PromoService{
promoRepo: promoRepo,
userRepo: userRepo,
billingCacheService: billingCacheService,
entClient: entClient,
}
}
// ValidatePromoCode 验证优惠码(注册前调用)
// 返回 nil, nil 表示空码(不报错)
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
code = strings.TrimSpace(code)
if code == "" {
return nil, nil // 空码不报错,直接返回
}
promoCode, err := s.promoRepo.GetByCode(ctx, code)
if err != nil {
// 保留原始错误类型,不要统一映射为 NotFound
return nil, err
}
if err := s.validatePromoCodeStatus(promoCode); err != nil {
return nil, err
}
return promoCode, nil
}
// validatePromoCodeStatus 验证优惠码状态
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
if !promoCode.CanUse() {
if promoCode.IsExpired() {
return ErrPromoCodeExpired
}
if promoCode.Status == PromoCodeStatusDisabled {
return ErrPromoCodeDisabled
}
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
return ErrPromoCodeMaxUsed
}
return ErrPromoCodeInvalid
}
return nil
}
// ApplyPromoCode 应用优惠码(注册成功后调用)
// 使用事务和行锁确保并发安全
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
code = strings.TrimSpace(code)
if code == "" {
return nil
}
// 开启事务
tx, err := s.entClient.Tx(ctx)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
// 在事务中获取并锁定优惠码记录FOR UPDATE
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
if err != nil {
return err
}
// 在事务中验证优惠码状态
if err := s.validatePromoCodeStatus(promoCode); err != nil {
return err
}
// 在事务中检查用户是否已使用过此优惠码
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
if err != nil {
return fmt.Errorf("check existing usage: %w", err)
}
if existing != nil {
return ErrPromoCodeAlreadyUsed
}
// 增加用户余额
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
return fmt.Errorf("update user balance: %w", err)
}
// 创建使用记录
usage := &PromoCodeUsage{
PromoCodeID: promoCode.ID,
UserID: userID,
BonusAmount: promoCode.BonusAmount,
UsedAt: time.Now(),
}
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
return fmt.Errorf("create usage record: %w", err)
}
// 增加使用次数
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
return fmt.Errorf("increment used count: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
return nil
}
// GenerateRandomCode 生成随机优惠码
func (s *PromoService) GenerateRandomCode() (string, error) {
bytes := make([]byte, 8)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
return strings.ToUpper(hex.EncodeToString(bytes)), nil
}
// Create 创建优惠码
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
code := strings.TrimSpace(input.Code)
if code == "" {
// 自动生成
var err error
code, err = s.GenerateRandomCode()
if err != nil {
return nil, err
}
}
promoCode := &PromoCode{
Code: strings.ToUpper(code),
BonusAmount: input.BonusAmount,
MaxUses: input.MaxUses,
UsedCount: 0,
Status: PromoCodeStatusActive,
ExpiresAt: input.ExpiresAt,
Notes: input.Notes,
}
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
return nil, fmt.Errorf("create promo code: %w", err)
}
return promoCode, nil
}
// GetByID 根据ID获取优惠码
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
code, err := s.promoRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return code, nil
}
// Update 更新优惠码
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
promoCode, err := s.promoRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Code != nil {
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
}
if input.BonusAmount != nil {
promoCode.BonusAmount = *input.BonusAmount
}
if input.MaxUses != nil {
promoCode.MaxUses = *input.MaxUses
}
if input.Status != nil {
promoCode.Status = *input.Status
}
if input.ExpiresAt != nil {
promoCode.ExpiresAt = input.ExpiresAt
}
if input.Notes != nil {
promoCode.Notes = *input.Notes
}
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
return nil, fmt.Errorf("update promo code: %w", err)
}
return promoCode, nil
}
// Delete 删除优惠码
func (s *PromoService) Delete(ctx context.Context, id int64) error {
if err := s.promoRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete promo code: %w", err)
}
return nil
}
// List 获取优惠码列表
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
return s.promoRepo.ListWithFilters(ctx, params, status, search)
}
// ListUsages 获取使用记录
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
}

View File

@@ -345,7 +345,7 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
@@ -353,7 +353,10 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.accountRepo.ClearRateLimit(ctx, accountID)
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
return err
}
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {

View File

@@ -87,6 +87,7 @@ var ProviderSet = wire.NewSet(
NewAccountService,
NewProxyService,
NewRedeemService,
NewPromoService,
NewUsageService,
NewDashboardService,
ProvidePricingService,