Merge pull request #815 from mt21625457/pr/openai-user-group-rate-upstream

fix(openai): 统一专属倍率计费链路并补齐回归测试
This commit is contained in:
Wesley Liddick
2026-03-06 17:33:09 +08:00
committed by GitHub
8 changed files with 690 additions and 112 deletions

View File

@@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
}
// NewGatewayService creates a new GatewayService
@@ -582,6 +583,13 @@ func NewGatewayService(
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
svc.userGroupRateCache,
userGroupRateTTL,
&svc.userGroupRateSF,
"service.gateway",
)
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return svc
@@ -6336,63 +6344,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
}
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if s == nil || userID <= 0 || groupID <= 0 {
if s == nil {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(
s.userGroupRateRepo,
s.userGroupRateCache,
resolveUserGroupRateCacheTTL(s.cfg),
&s.userGroupRateSF,
"service.gateway",
)
}
if s.userGroupRateRepo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if s.userGroupRateCache != nil {
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
}
// RecordUsageInput 记录使用量的输入参数