From a18bbb5f2fb5d2e623dc18daa18ae3acc7e86973 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Fri, 6 Mar 2026 14:54:52 +0800 Subject: [PATCH] =?UTF-8?q?fix(openai):=20=E7=BB=9F=E4=B8=80=E4=B8=93?= =?UTF-8?q?=E5=B1=9E=E5=80=8D=E7=8E=87=E8=AE=A1=E8=B4=B9=E9=93=BE=E8=B7=AF?= =?UTF-8?q?=E5=B9=B6=E8=A1=A5=E9=BD=90=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 --- backend/cmd/server/wire_gen.go | 2 +- backend/internal/service/gateway_service.go | 127 +++---- .../openai_gateway_record_usage_test.go | 338 ++++++++++++++++++ .../service/openai_gateway_service.go | 126 +++++-- .../service/openai_gateway_service_test.go | 24 ++ .../openai_ws_protocol_forward_test.go | 1 + .../service/user_group_rate_resolver.go | 103 ++++++ .../service/user_group_rate_resolver_test.go | 83 +++++ 8 files changed, 692 insertions(+), 112 deletions(-) create mode 100644 backend/internal/service/openai_gateway_record_usage_test.go create mode 100644 backend/internal/service/user_group_rate_resolver.go create mode 100644 backend/internal/service/user_group_rate_resolver_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 60bb17d5..041a1368 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { digestSessionStore := service.NewDigestSessionStore() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 075f3ef0..5c7f9c29 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) - userGroupRateCache *gocache.Cache - userGroupRateSF singleflight.Group - modelsListCache *gocache.Cache - modelsListCacheTTL time.Duration - responseHeaderFilter *responseheaders.CompiledHeaderFilter - debugModelRouting atomic.Bool - debugClaudeMimic atomic.Bool + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateResolver *userGroupRateResolver + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -582,6 +583,13 @@ func NewGatewayService( modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.userGroupRateResolver = newUserGroupRateResolver( + userGroupRateRepo, + svc.userGroupRateCache, + userGroupRateTTL, + &svc.userGroupRateSF, + "service.gateway", + ) svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) return svc @@ -6332,63 +6340,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo } func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { - if s == nil || userID <= 0 || groupID <= 0 { + if s == nil { return groupDefaultMultiplier } - - key := fmt.Sprintf("%d:%d", userID, groupID) - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier - } - } + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver( + s.userGroupRateRepo, + s.userGroupRateCache, + resolveUserGroupRateCacheTTL(s.cfg), + &s.userGroupRateSF, + "service.gateway", + ) } - if s.userGroupRateRepo == nil { - return groupDefaultMultiplier - } - userGroupRateCacheMissTotal.Add(1) - - value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { - if s.userGroupRateCache != nil { - if cached, ok := s.userGroupRateCache.Get(key); ok { - if multiplier, castOK := cached.(float64); castOK { - userGroupRateCacheHitTotal.Add(1) - return multiplier, nil - } - } - } - - userGroupRateCacheLoadTotal.Add(1) - userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) - if repoErr != nil { - return nil, repoErr - } - multiplier := groupDefaultMultiplier - if userRate != nil { - multiplier = *userRate - } - if s.userGroupRateCache != nil { - s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) - } - return multiplier, nil - }) - if shared { - userGroupRateCacheSFSharedTotal.Add(1) - } - if err != nil { - userGroupRateCacheFallbackTotal.Add(1) - logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) - return groupDefaultMultiplier - } - - multiplier, ok := value.(float64) - if !ok { - userGroupRateCacheFallbackTotal.Add(1) - return groupDefaultMultiplier - } - return multiplier + return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) } // RecordUsageInput 记录使用量的输入参数 diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go new file mode 100644 index 00000000..421f3fec --- /dev/null +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -0,0 +1,338 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIRecordUsageLogRepoStub struct { + UsageLogRepository + + inserted bool + err error + calls int + lastLog *UsageLog +} + +func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.calls++ + s.lastLog = log + return s.inserted, s.err +} + +type openAIRecordUsageUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error + lastAmount float64 +} + +func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + s.lastAmount = amount + return s.deductErr +} + +type openAIRecordUsageSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error + lastAmount float64 +} + +func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + s.lastAmount = costUSD + return s.incrementErr +} + +type openAIRecordUsageAPIKeyQuotaStub struct { + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + s.quotaCalls++ + s.lastAmount = cost + return s.err +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + s.rateLimitCalls++ + s.lastAmount = cost + return s.err +} + +type openAIUserGroupRateRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func i64p(v int64) *int64 { + return &v +} + +func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + + return &OpenAIGatewayService{ + usageLogRepo: usageRepo, + userRepo: userRepo, + userSubRepo: subRepo, + cfg: cfg, + billingService: NewBillingService(cfg, nil), + billingCacheService: &BillingCacheService{}, + deferredService: &DeferredService{}, + userGroupRateResolver: newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ), + } +} + +func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { + t.Helper() + + cost, err := svc.billingService.CalculateCost(model, UsageTokens{ + InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0), + OutputTokens: usage.OutputTokens, + CacheCreationTokens: usage.CacheCreationInputTokens, + CacheReadTokens: usage.CacheReadInputTokens, + }, multiplier) + require.NoError(t, err) + return cost +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { + groupID := int64(11) + groupRate := 1.4 + userRate := 1.8 + usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_user_group_rate", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1001, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2001}, + Account: &Account{ID: 3001}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier) + require.Equal(t, 12, usageRepo.lastLog.InputTokens) + require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate) + require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) { + groupID := int64(12) + groupRate := 1.6 + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_on_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2002}, + Account: &Account{ID: 3002}, + }) + + require.NoError(t, err) + require.Equal(t, 1, rateRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) + + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate) + require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) { + groupID := int64(13) + groupRate := 1.25 + usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1} + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.userGroupRateResolver = nil + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_group_default_nil_resolver", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1003, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: groupRate, + }, + }, + User: &User{ID: 2003}, + Account: &Account{ID: 3003}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1004}, + User: &User{ID: 2004}, + Account: &Account{ID: 3004}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_quota_update", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1005, + Quota: 100, + }, + User: &User{ID: 2005}, + Account: &Account{ID: 3005}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.Equal(t, 0, quotaSvc.rateLimitCalls) + expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1) + require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_clamp_actual_input", + Usage: OpenAIUsage{ + InputTokens: 2, + OutputTokens: 1, + CacheReadInputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1006}, + User: &User{ID: 2006}, + Account: &Account{ID: 3006}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 0, usageRepo.lastLog.InputTokens) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 73bdba65..9970fc19 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct { // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - codexDetector CodexClientRestrictionDetector - schedulerSnapshot *SchedulerSnapshotService - concurrencyService *ConcurrencyService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - httpUpstream HTTPUpstream - deferredService *DeferredService - openAITokenProvider *OpenAITokenProvider - toolCorrector *CodexToolCorrector - openaiWSResolver OpenAIWSProtocolResolver + accountRepo AccountRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + codexDetector CodexClientRestrictionDetector + schedulerSnapshot *SchedulerSnapshotService + concurrencyService *ConcurrencyService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + userGroupRateResolver *userGroupRateResolver + httpUpstream HTTPUpstream + deferredService *DeferredService + openAITokenProvider *OpenAITokenProvider + toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -284,6 +285,7 @@ func NewOpenAIGatewayService( usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -296,18 +298,25 @@ func NewOpenAIGatewayService( openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + userGroupRateResolver: newUserGroupRateResolver( + userGroupRateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway", + ), httpUpstream: httpUpstream, deferredService: deferredService, openAITokenProvider: openAITokenProvider, @@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. // Correct tool calls in final response body = s.correctToolCallsInResponseBody(body) } else { + terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) + if terminalOK && terminalType == "response.failed" { + msg := extractOpenAISSEErrorMessage(terminalPayload) + if msg == "" { + msg = "Upstream compact response failed" + } + return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) + } usage = s.parseSSEUsageFromBody(bodyText) if originalModel != mappedModel { bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) @@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. return usage, nil } +func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok || data == "" || data == "[DONE]" { + continue + } + eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + switch eventType { + case "response.completed", "response.done", "response.failed": + return eventType, []byte(data), true + } + } + return "", nil, false +} + +func extractOpenAISSEErrorMessage(payload []byte) string { + if len(payload) == 0 { + return "" + } + for _, path := range []string{"response.error.message", "error.message", "message"} { + if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" { + return sanitizeUpstreamErrorMessage(msg) + } + } + return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload))) +} + +func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "Upstream returned an invalid non-streaming response" + } + setOpsUpstreamError(c, http.StatusBadGateway, message, "") + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return fmt.Errorf("non-streaming openai protocol error: %s", message) +} + func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { @@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Get rate multiplier multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway") + } + multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 4f5f7f3c..8fc29e75 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) } + +func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.Nil(t, usage) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, rec.Body.String(), "upstream rejected request") + require.Contains(t, rec.Header().Get("Content-Type"), "application/json") +} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index df4d4871..7295b13d 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, cfg, nil, nil, diff --git a/backend/internal/service/user_group_rate_resolver.go b/backend/internal/service/user_group_rate_resolver.go new file mode 100644 index 00000000..7f0ffb0f --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + gocache "github.com/patrickmn/go-cache" + "golang.org/x/sync/singleflight" +) + +type userGroupRateResolver struct { + repo UserGroupRateRepository + cache *gocache.Cache + cacheTTL time.Duration + sf *singleflight.Group + logComponent string +} + +func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver { + if cacheTTL <= 0 { + cacheTTL = defaultUserGroupRateCacheTTL + } + if cache == nil { + cache = gocache.New(cacheTTL, time.Minute) + } + if logComponent == "" { + logComponent = "service.gateway" + } + if sf == nil { + sf = &singleflight.Group{} + } + + return &userGroupRateResolver{ + repo: repo, + cache: cache, + cacheTTL: cacheTTL, + sf: sf, + logComponent: logComponent, + } +} + +func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if r == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier + } + + key := fmt.Sprintf("%d:%d", userID, groupID) + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } + } + if r.repo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) + + value, err, shared := r.sf.Do(key, func() (any, error) { + if r.cache != nil { + if cached, ok := r.cache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if r.cache != nil { + r.cache.Set(key, multiplier, r.cacheTTL) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } + if err != nil { + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier + } + + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier +} diff --git a/backend/internal/service/user_group_rate_resolver_test.go b/backend/internal/service/user_group_rate_resolver_test.go new file mode 100644 index 00000000..064ef7ba --- /dev/null +++ b/backend/internal/service/user_group_rate_resolver_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "context" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateResolverRepoStub struct { + UserGroupRateRepository + + rate *float64 + err error + calls int +} + +func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls++ + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +func TestNewUserGroupRateResolver_Defaults(t *testing.T) { + resolver := newUserGroupRateResolver(nil, nil, 0, nil, "") + + require.NotNil(t, resolver) + require.NotNil(t, resolver.cache) + require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL) + require.NotNil(t, resolver.sf) + require.Equal(t, "service.gateway", resolver.logComponent) +} + +func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) { + var nilResolver *userGroupRateResolver + require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4)) + + resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test") + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4)) + require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4)) +} + +func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + repo := &userGroupRateResolverRepoStub{rate: &rate} + cache := gocache.New(time.Minute, time.Minute) + cache.Set("101:202", "bad-cache", time.Minute) + resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test") + + got := resolver.Resolve(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) + + cached, ok := cache.Get("101:202") + require.True(t, ok) + require.Equal(t, rate, cached) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(0), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), load) + require.Equal(t, int64(0), fallback) +} + +func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) { + var nilSvc *GatewayService + require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3)) + + rate := 1.9 + repo := &userGroupRateResolverRepoStub{rate: &rate} + resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway") + svc := &GatewayService{userGroupRateResolver: resolver} + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, 1, repo.calls) +}