fix(openai): 统一专属倍率计费链路并补齐回归测试
抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
|
|||||||
@@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
|
|||||||
|
|
||||||
// GatewayService handles API gateway operations
|
// GatewayService handles API gateway operations
|
||||||
type GatewayService struct {
|
type GatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
usageLogRepo UsageLogRepository
|
usageLogRepo UsageLogRepository
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
userGroupRateRepo UserGroupRateRepository
|
userGroupRateRepo UserGroupRateRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
digestStore *DigestSessionStore
|
digestStore *DigestSessionStore
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
schedulerSnapshot *SchedulerSnapshotService
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
identityService *IdentityService
|
identityService *IdentityService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
deferredService *DeferredService
|
deferredService *DeferredService
|
||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
userGroupRateCache *gocache.Cache
|
userGroupRateResolver *userGroupRateResolver
|
||||||
userGroupRateSF singleflight.Group
|
userGroupRateCache *gocache.Cache
|
||||||
modelsListCache *gocache.Cache
|
userGroupRateSF singleflight.Group
|
||||||
modelsListCacheTTL time.Duration
|
modelsListCache *gocache.Cache
|
||||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
modelsListCacheTTL time.Duration
|
||||||
debugModelRouting atomic.Bool
|
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||||
debugClaudeMimic atomic.Bool
|
debugModelRouting atomic.Bool
|
||||||
|
debugClaudeMimic atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayService creates a new GatewayService
|
// NewGatewayService creates a new GatewayService
|
||||||
@@ -582,6 +583,13 @@ func NewGatewayService(
|
|||||||
modelsListCacheTTL: modelsListTTL,
|
modelsListCacheTTL: modelsListTTL,
|
||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
}
|
}
|
||||||
|
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||||
|
userGroupRateRepo,
|
||||||
|
svc.userGroupRateCache,
|
||||||
|
userGroupRateTTL,
|
||||||
|
&svc.userGroupRateSF,
|
||||||
|
"service.gateway",
|
||||||
|
)
|
||||||
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||||
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
|
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
|
||||||
return svc
|
return svc
|
||||||
@@ -6332,63 +6340,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||||
if s == nil || userID <= 0 || groupID <= 0 {
|
if s == nil {
|
||||||
return groupDefaultMultiplier
|
return groupDefaultMultiplier
|
||||||
}
|
}
|
||||||
|
resolver := s.userGroupRateResolver
|
||||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
if resolver == nil {
|
||||||
if s.userGroupRateCache != nil {
|
resolver = newUserGroupRateResolver(
|
||||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
s.userGroupRateRepo,
|
||||||
if multiplier, castOK := cached.(float64); castOK {
|
s.userGroupRateCache,
|
||||||
userGroupRateCacheHitTotal.Add(1)
|
resolveUserGroupRateCacheTTL(s.cfg),
|
||||||
return multiplier
|
&s.userGroupRateSF,
|
||||||
}
|
"service.gateway",
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
if s.userGroupRateRepo == nil {
|
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
|
||||||
return groupDefaultMultiplier
|
|
||||||
}
|
|
||||||
userGroupRateCacheMissTotal.Add(1)
|
|
||||||
|
|
||||||
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
|
||||||
if s.userGroupRateCache != nil {
|
|
||||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
|
||||||
if multiplier, castOK := cached.(float64); castOK {
|
|
||||||
userGroupRateCacheHitTotal.Add(1)
|
|
||||||
return multiplier, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
userGroupRateCacheLoadTotal.Add(1)
|
|
||||||
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
|
||||||
if repoErr != nil {
|
|
||||||
return nil, repoErr
|
|
||||||
}
|
|
||||||
multiplier := groupDefaultMultiplier
|
|
||||||
if userRate != nil {
|
|
||||||
multiplier = *userRate
|
|
||||||
}
|
|
||||||
if s.userGroupRateCache != nil {
|
|
||||||
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
|
||||||
}
|
|
||||||
return multiplier, nil
|
|
||||||
})
|
|
||||||
if shared {
|
|
||||||
userGroupRateCacheSFSharedTotal.Add(1)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
userGroupRateCacheFallbackTotal.Add(1)
|
|
||||||
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
|
||||||
return groupDefaultMultiplier
|
|
||||||
}
|
|
||||||
|
|
||||||
multiplier, ok := value.(float64)
|
|
||||||
if !ok {
|
|
||||||
userGroupRateCacheFallbackTotal.Add(1)
|
|
||||||
return groupDefaultMultiplier
|
|
||||||
}
|
|
||||||
return multiplier
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
|
|||||||
338
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
338
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIRecordUsageLogRepoStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
|
||||||
|
inserted bool
|
||||||
|
err error
|
||||||
|
calls int
|
||||||
|
lastLog *UsageLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||||
|
s.calls++
|
||||||
|
s.lastLog = log
|
||||||
|
return s.inserted, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIRecordUsageUserRepoStub struct {
|
||||||
|
UserRepository
|
||||||
|
|
||||||
|
deductCalls int
|
||||||
|
deductErr error
|
||||||
|
lastAmount float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||||
|
s.deductCalls++
|
||||||
|
s.lastAmount = amount
|
||||||
|
return s.deductErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIRecordUsageSubRepoStub struct {
|
||||||
|
UserSubscriptionRepository
|
||||||
|
|
||||||
|
incrementCalls int
|
||||||
|
incrementErr error
|
||||||
|
lastAmount float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||||
|
s.incrementCalls++
|
||||||
|
s.lastAmount = costUSD
|
||||||
|
return s.incrementErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIRecordUsageAPIKeyQuotaStub struct {
|
||||||
|
quotaCalls int
|
||||||
|
rateLimitCalls int
|
||||||
|
err error
|
||||||
|
lastAmount float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||||
|
s.quotaCalls++
|
||||||
|
s.lastAmount = cost
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||||
|
s.rateLimitCalls++
|
||||||
|
s.lastAmount = cost
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIUserGroupRateRepoStub struct {
|
||||||
|
UserGroupRateRepository
|
||||||
|
|
||||||
|
rate *float64
|
||||||
|
err error
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||||
|
s.calls++
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.rate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func i64p(v int64) *int64 {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Default.RateMultiplier = 1.1
|
||||||
|
|
||||||
|
return &OpenAIGatewayService{
|
||||||
|
usageLogRepo: usageRepo,
|
||||||
|
userRepo: userRepo,
|
||||||
|
userSubRepo: subRepo,
|
||||||
|
cfg: cfg,
|
||||||
|
billingService: NewBillingService(cfg, nil),
|
||||||
|
billingCacheService: &BillingCacheService{},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
userGroupRateResolver: newUserGroupRateResolver(
|
||||||
|
rateRepo,
|
||||||
|
nil,
|
||||||
|
resolveUserGroupRateCacheTTL(cfg),
|
||||||
|
nil,
|
||||||
|
"service.openai_gateway.test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
|
||||||
|
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
CacheCreationTokens: usage.CacheCreationInputTokens,
|
||||||
|
CacheReadTokens: usage.CacheReadInputTokens,
|
||||||
|
}, multiplier)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
|
||||||
|
func max(a, b int) int {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||||
|
groupID := int64(11)
|
||||||
|
groupRate := 1.4
|
||||||
|
userRate := 1.8
|
||||||
|
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
|
||||||
|
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_user_group_rate",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1001,
|
||||||
|
GroupID: i64p(groupID),
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
RateMultiplier: groupRate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2001},
|
||||||
|
Account: &Account{ID: 3001},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, rateRepo.calls)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
|
||||||
|
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
|
||||||
|
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
|
||||||
|
|
||||||
|
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
|
||||||
|
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||||
|
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
|
||||||
|
groupID := int64(12)
|
||||||
|
groupRate := 1.6
|
||||||
|
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
|
||||||
|
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_group_default_on_error",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1002,
|
||||||
|
GroupID: i64p(groupID),
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
RateMultiplier: groupRate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2002},
|
||||||
|
Account: &Account{ID: 3002},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, rateRepo.calls)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||||
|
|
||||||
|
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
|
||||||
|
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
|
||||||
|
groupID := int64(13)
|
||||||
|
groupRate := 1.25
|
||||||
|
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
|
||||||
|
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
svc.userGroupRateResolver = nil
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_group_default_nil_resolver",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1003,
|
||||||
|
GroupID: i64p(groupID),
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
RateMultiplier: groupRate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2003},
|
||||||
|
Account: &Account{ID: 3003},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_duplicate",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1004},
|
||||||
|
User: &User{ID: 2004},
|
||||||
|
Account: &Account{ID: 3004},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||||
|
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_quota_update",
|
||||||
|
Usage: usage,
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1005,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 2005},
|
||||||
|
Account: &Account{ID: 3005},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||||
|
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
|
||||||
|
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_clamp_actual_input",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 2,
|
||||||
|
OutputTokens: 1,
|
||||||
|
CacheReadInputTokens: 5,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1006},
|
||||||
|
User: &User{ID: 2006},
|
||||||
|
Account: &Account{ID: 3006},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
|
||||||
|
}
|
||||||
@@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct {
|
|||||||
|
|
||||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||||
type OpenAIGatewayService struct {
|
type OpenAIGatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
usageLogRepo UsageLogRepository
|
usageLogRepo UsageLogRepository
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
codexDetector CodexClientRestrictionDetector
|
codexDetector CodexClientRestrictionDetector
|
||||||
schedulerSnapshot *SchedulerSnapshotService
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
httpUpstream HTTPUpstream
|
userGroupRateResolver *userGroupRateResolver
|
||||||
deferredService *DeferredService
|
httpUpstream HTTPUpstream
|
||||||
openAITokenProvider *OpenAITokenProvider
|
deferredService *DeferredService
|
||||||
toolCorrector *CodexToolCorrector
|
openAITokenProvider *OpenAITokenProvider
|
||||||
openaiWSResolver OpenAIWSProtocolResolver
|
toolCorrector *CodexToolCorrector
|
||||||
|
openaiWSResolver OpenAIWSProtocolResolver
|
||||||
|
|
||||||
openaiWSPoolOnce sync.Once
|
openaiWSPoolOnce sync.Once
|
||||||
openaiWSStateStoreOnce sync.Once
|
openaiWSStateStoreOnce sync.Once
|
||||||
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
|
|||||||
usageLogRepo UsageLogRepository,
|
usageLogRepo UsageLogRepository,
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
|
userGroupRateRepo UserGroupRateRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
schedulerSnapshot *SchedulerSnapshotService,
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
@@ -296,18 +298,25 @@ func NewOpenAIGatewayService(
|
|||||||
openAITokenProvider *OpenAITokenProvider,
|
openAITokenProvider *OpenAITokenProvider,
|
||||||
) *OpenAIGatewayService {
|
) *OpenAIGatewayService {
|
||||||
svc := &OpenAIGatewayService{
|
svc := &OpenAIGatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
|
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
|
||||||
schedulerSnapshot: schedulerSnapshot,
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
|
userGroupRateResolver: newUserGroupRateResolver(
|
||||||
|
userGroupRateRepo,
|
||||||
|
nil,
|
||||||
|
resolveUserGroupRateCacheTTL(cfg),
|
||||||
|
nil,
|
||||||
|
"service.openai_gateway",
|
||||||
|
),
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
openAITokenProvider: openAITokenProvider,
|
openAITokenProvider: openAITokenProvider,
|
||||||
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
|||||||
// Correct tool calls in final response
|
// Correct tool calls in final response
|
||||||
body = s.correctToolCallsInResponseBody(body)
|
body = s.correctToolCallsInResponseBody(body)
|
||||||
} else {
|
} else {
|
||||||
|
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
|
||||||
|
if terminalOK && terminalType == "response.failed" {
|
||||||
|
msg := extractOpenAISSEErrorMessage(terminalPayload)
|
||||||
|
if msg == "" {
|
||||||
|
msg = "Upstream compact response failed"
|
||||||
|
}
|
||||||
|
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
|
||||||
|
}
|
||||||
usage = s.parseSSEUsageFromBody(bodyText)
|
usage = s.parseSSEUsageFromBody(bodyText)
|
||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||||
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
|||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
data, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok || data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed", "response.done", "response.failed":
|
||||||
|
return eventType, []byte(data), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAISSEErrorMessage(payload []byte) string {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, path := range []string{"response.error.message", "error.message", "message"} {
|
||||||
|
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
|
||||||
|
return sanitizeUpstreamErrorMessage(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
|
||||||
|
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
|
||||||
|
if message == "" {
|
||||||
|
message = "Upstream returned an invalid non-streaming response"
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return fmt.Errorf("non-streaming openai protocol error: %s", message)
|
||||||
|
}
|
||||||
|
|
||||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||||
lines := strings.Split(body, "\n")
|
lines := strings.Split(body, "\n")
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
// Get rate multiplier
|
// Get rate multiplier
|
||||||
multiplier := s.cfg.Default.RateMultiplier
|
multiplier := s.cfg.Default.RateMultiplier
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||||
multiplier = apiKey.Group.RateMultiplier
|
resolver := s.userGroupRateResolver
|
||||||
|
if resolver == nil {
|
||||||
|
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
|
||||||
|
}
|
||||||
|
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||||
|
|||||||
@@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
|
|||||||
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||||
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
body := []byte(strings.Join([]string{
|
||||||
|
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}, "\n"))
|
||||||
|
|
||||||
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||||
|
require.Nil(t, usage)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||||
|
require.Contains(t, rec.Body.String(), "upstream rejected request")
|
||||||
|
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
|
||||||
|
}
|
||||||
|
|||||||
@@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
|||||||
103
backend/internal/service/user_group_rate_resolver.go
Normal file
103
backend/internal/service/user_group_rate_resolver.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userGroupRateResolver struct {
|
||||||
|
repo UserGroupRateRepository
|
||||||
|
cache *gocache.Cache
|
||||||
|
cacheTTL time.Duration
|
||||||
|
sf *singleflight.Group
|
||||||
|
logComponent string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver {
|
||||||
|
if cacheTTL <= 0 {
|
||||||
|
cacheTTL = defaultUserGroupRateCacheTTL
|
||||||
|
}
|
||||||
|
if cache == nil {
|
||||||
|
cache = gocache.New(cacheTTL, time.Minute)
|
||||||
|
}
|
||||||
|
if logComponent == "" {
|
||||||
|
logComponent = "service.gateway"
|
||||||
|
}
|
||||||
|
if sf == nil {
|
||||||
|
sf = &singleflight.Group{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &userGroupRateResolver{
|
||||||
|
repo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cacheTTL: cacheTTL,
|
||||||
|
sf: sf,
|
||||||
|
logComponent: logComponent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||||
|
if r == nil || userID <= 0 || groupID <= 0 {
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||||
|
if r.cache != nil {
|
||||||
|
if cached, ok := r.cache.Get(key); ok {
|
||||||
|
if multiplier, castOK := cached.(float64); castOK {
|
||||||
|
userGroupRateCacheHitTotal.Add(1)
|
||||||
|
return multiplier
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.repo == nil {
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
userGroupRateCacheMissTotal.Add(1)
|
||||||
|
|
||||||
|
value, err, shared := r.sf.Do(key, func() (any, error) {
|
||||||
|
if r.cache != nil {
|
||||||
|
if cached, ok := r.cache.Get(key); ok {
|
||||||
|
if multiplier, castOK := cached.(float64); castOK {
|
||||||
|
userGroupRateCacheHitTotal.Add(1)
|
||||||
|
return multiplier, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userGroupRateCacheLoadTotal.Add(1)
|
||||||
|
userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID)
|
||||||
|
if repoErr != nil {
|
||||||
|
return nil, repoErr
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplier := groupDefaultMultiplier
|
||||||
|
if userRate != nil {
|
||||||
|
multiplier = *userRate
|
||||||
|
}
|
||||||
|
if r.cache != nil {
|
||||||
|
r.cache.Set(key, multiplier, r.cacheTTL)
|
||||||
|
}
|
||||||
|
return multiplier, nil
|
||||||
|
})
|
||||||
|
if shared {
|
||||||
|
userGroupRateCacheSFSharedTotal.Add(1)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
userGroupRateCacheFallbackTotal.Add(1)
|
||||||
|
logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplier, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
userGroupRateCacheFallbackTotal.Add(1)
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
return multiplier
|
||||||
|
}
|
||||||
83
backend/internal/service/user_group_rate_resolver_test.go
Normal file
83
backend/internal/service/user_group_rate_resolver_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userGroupRateResolverRepoStub struct {
|
||||||
|
UserGroupRateRepository
|
||||||
|
|
||||||
|
rate *float64
|
||||||
|
err error
|
||||||
|
calls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||||
|
s.calls++
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.rate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUserGroupRateResolver_Defaults(t *testing.T) {
|
||||||
|
resolver := newUserGroupRateResolver(nil, nil, 0, nil, "")
|
||||||
|
|
||||||
|
require.NotNil(t, resolver)
|
||||||
|
require.NotNil(t, resolver.cache)
|
||||||
|
require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL)
|
||||||
|
require.NotNil(t, resolver.sf)
|
||||||
|
require.Equal(t, "service.gateway", resolver.logComponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) {
|
||||||
|
var nilResolver *userGroupRateResolver
|
||||||
|
require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4))
|
||||||
|
|
||||||
|
resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test")
|
||||||
|
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4))
|
||||||
|
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) {
|
||||||
|
resetGatewayHotpathStatsForTest()
|
||||||
|
|
||||||
|
rate := 1.7
|
||||||
|
repo := &userGroupRateResolverRepoStub{rate: &rate}
|
||||||
|
cache := gocache.New(time.Minute, time.Minute)
|
||||||
|
cache.Set("101:202", "bad-cache", time.Minute)
|
||||||
|
resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test")
|
||||||
|
|
||||||
|
got := resolver.Resolve(context.Background(), 101, 202, 1.2)
|
||||||
|
require.Equal(t, rate, got)
|
||||||
|
require.Equal(t, 1, repo.calls)
|
||||||
|
|
||||||
|
cached, ok := cache.Get("101:202")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, rate, cached)
|
||||||
|
|
||||||
|
hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
|
||||||
|
require.Equal(t, int64(0), hit)
|
||||||
|
require.Equal(t, int64(1), miss)
|
||||||
|
require.Equal(t, int64(1), load)
|
||||||
|
require.Equal(t, int64(0), fallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) {
|
||||||
|
var nilSvc *GatewayService
|
||||||
|
require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3))
|
||||||
|
|
||||||
|
rate := 1.9
|
||||||
|
repo := &userGroupRateResolverRepoStub{rate: &rate}
|
||||||
|
resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway")
|
||||||
|
svc := &GatewayService{userGroupRateResolver: resolver}
|
||||||
|
|
||||||
|
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||||
|
require.Equal(t, rate, got)
|
||||||
|
require.Equal(t, 1, repo.calls)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user