fix(openai): 统一专属倍率计费链路并补齐回归测试

抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-03-06 14:54:52 +08:00
parent 63a8c76946
commit a18bbb5f2f
8 changed files with 692 additions and 112 deletions

View File

@@ -501,33 +501,34 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
}
// NewGatewayService creates a new GatewayService
@@ -582,6 +583,13 @@ func NewGatewayService(
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
svc.userGroupRateCache,
userGroupRateTTL,
&svc.userGroupRateSF,
"service.gateway",
)
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return svc
@@ -6332,63 +6340,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
}
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if s == nil || userID <= 0 || groupID <= 0 {
if s == nil {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(
s.userGroupRateRepo,
s.userGroupRateCache,
resolveUserGroupRateCacheTTL(s.cfg),
&s.userGroupRateSF,
"service.gateway",
)
}
if s.userGroupRateRepo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if s.userGroupRateCache != nil {
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
}
// RecordUsageInput 记录使用量的输入参数

View File

@@ -0,0 +1,338 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type openAIRecordUsageLogRepoStub struct {
UsageLogRepository
inserted bool
err error
calls int
lastLog *UsageLog
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
return s.inserted, s.err
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
return s.deductErr
}
type openAIRecordUsageSubRepoStub struct {
UserSubscriptionRepository
incrementCalls int
incrementErr error
lastAmount float64
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
s.lastAmount = costUSD
return s.incrementErr
}
type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
return s.err
}
type openAIUserGroupRateRepoStub struct {
UserGroupRateRepository
rate *float64
err error
calls int
}
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
s.calls++
if s.err != nil {
return nil, s.err
}
return s.rate, nil
}
func i64p(v int64) *int64 {
return &v
}
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
t.Helper()
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
OutputTokens: usage.OutputTokens,
CacheCreationTokens: usage.CacheCreationInputTokens,
CacheReadTokens: usage.CacheReadInputTokens,
}, multiplier)
require.NoError(t, err)
return cost
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
groupID := int64(11)
groupRate := 1.4
userRate := 1.8
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_user_group_rate",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1001,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2001},
Account: &Account{ID: 3001},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
groupID := int64(12)
groupRate := 1.6
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_on_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1002,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2002},
Account: &Account{ID: 3002},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
groupID := int64(13)
groupRate := 1.25
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.userGroupRateResolver = nil
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_nil_resolver",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1003,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2003},
Account: &Account{ID: 3003},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1004},
User: &User{ID: 2004},
Account: &Account{ID: 3004},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_quota_update",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1005,
Quota: 100,
},
User: &User{ID: 2005},
Account: &Account{ID: 3005},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.Equal(t, 0, quotaSvc.rateLimitCalls)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_clamp_actual_input",
Usage: OpenAIUsage{
InputTokens: 2,
OutputTokens: 1,
CacheReadInputTokens: 5,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1006},
User: &User{ID: 2006},
Account: &Account{ID: 3006},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
}

View File

@@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct {
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
codexDetector CodexClientRestrictionDetector
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
codexDetector CodexClientRestrictionDetector
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
userGroupRateResolver *userGroupRateResolver
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
@@ -296,18 +298,25 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
userGroupRateResolver: newUserGroupRateResolver(
userGroupRateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
if terminalOK && terminalType == "response.failed" {
msg := extractOpenAISSEErrorMessage(terminalPayload)
if msg == "" {
msg = "Upstream compact response failed"
}
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
return usage, nil
}
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed":
return eventType, []byte(data), true
}
}
return "", nil, false
}
func extractOpenAISSEErrorMessage(payload []byte) string {
if len(payload) == 0 {
return ""
}
for _, path := range []string{"response.error.message", "error.message", "message"} {
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
return sanitizeUpstreamErrorMessage(msg)
}
}
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
}
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
if message == "" {
message = "Upstream returned an invalid non-streaming response"
}
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": message,
},
})
return fmt.Errorf("non-streaming openai protocol error: %s", message)
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
}
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)

View File

@@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
}
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
svc := &OpenAIGatewayService{cfg: &config.Config{}}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
}
body := []byte(strings.Join([]string{
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`,
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.Nil(t, usage)
require.Error(t, err)
require.Equal(t, http.StatusBadGateway, rec.Code)
require.Contains(t, rec.Body.String(), "upstream rejected request")
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
}

View File

@@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
cfg,
nil,
nil,

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
gocache "github.com/patrickmn/go-cache"
"golang.org/x/sync/singleflight"
)
type userGroupRateResolver struct {
repo UserGroupRateRepository
cache *gocache.Cache
cacheTTL time.Duration
sf *singleflight.Group
logComponent string
}
func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver {
if cacheTTL <= 0 {
cacheTTL = defaultUserGroupRateCacheTTL
}
if cache == nil {
cache = gocache.New(cacheTTL, time.Minute)
}
if logComponent == "" {
logComponent = "service.gateway"
}
if sf == nil {
sf = &singleflight.Group{}
}
return &userGroupRateResolver{
repo: repo,
cache: cache,
cacheTTL: cacheTTL,
sf: sf,
logComponent: logComponent,
}
}
func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if r == nil || userID <= 0 || groupID <= 0 {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if r.cache != nil {
if cached, ok := r.cache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
}
if r.repo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := r.sf.Do(key, func() (any, error) {
if r.cache != nil {
if cached, ok := r.cache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if r.cache != nil {
r.cache.Set(key, multiplier, r.cacheTTL)
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
}

View File

@@ -0,0 +1,83 @@
package service
import (
"context"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/require"
)
type userGroupRateResolverRepoStub struct {
UserGroupRateRepository
rate *float64
err error
calls int
}
func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
s.calls++
if s.err != nil {
return nil, s.err
}
return s.rate, nil
}
func TestNewUserGroupRateResolver_Defaults(t *testing.T) {
resolver := newUserGroupRateResolver(nil, nil, 0, nil, "")
require.NotNil(t, resolver)
require.NotNil(t, resolver.cache)
require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL)
require.NotNil(t, resolver.sf)
require.Equal(t, "service.gateway", resolver.logComponent)
}
func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) {
var nilResolver *userGroupRateResolver
require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4))
resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test")
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4))
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4))
}
func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) {
resetGatewayHotpathStatsForTest()
rate := 1.7
repo := &userGroupRateResolverRepoStub{rate: &rate}
cache := gocache.New(time.Minute, time.Minute)
cache.Set("101:202", "bad-cache", time.Minute)
resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test")
got := resolver.Resolve(context.Background(), 101, 202, 1.2)
require.Equal(t, rate, got)
require.Equal(t, 1, repo.calls)
cached, ok := cache.Get("101:202")
require.True(t, ok)
require.Equal(t, rate, cached)
hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
require.Equal(t, int64(0), hit)
require.Equal(t, int64(1), miss)
require.Equal(t, int64(1), load)
require.Equal(t, int64(0), fallback)
}
func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) {
var nilSvc *GatewayService
require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3))
rate := 1.9
repo := &userGroupRateResolverRepoStub{rate: &rate}
resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway")
svc := &GatewayService{userGroupRateResolver: resolver}
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
require.Equal(t, rate, got)
require.Equal(t, 1, repo.calls)
}