Merge pull request #815 from mt21625457/pr/openai-user-group-rate-upstream

fix(openai): 统一专属倍率计费链路并补齐回归测试
This commit is contained in:
Wesley Liddick
2026-03-06 17:33:09 +08:00
committed by GitHub
8 changed files with 690 additions and 112 deletions

View File

@@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@@ -521,6 +521,7 @@ type GatewayService struct {
claudeTokenProvider *ClaudeTokenProvider claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache modelsListCache *gocache.Cache
@@ -582,6 +583,13 @@ func NewGatewayService(
modelsListCacheTTL: modelsListTTL, modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
} }
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
svc.userGroupRateCache,
userGroupRateTTL,
&svc.userGroupRateSF,
"service.gateway",
)
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return svc return svc
@@ -6336,63 +6344,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
} }
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if s == nil || userID <= 0 || groupID <= 0 { if s == nil {
return groupDefaultMultiplier return groupDefaultMultiplier
} }
resolver := s.userGroupRateResolver
key := fmt.Sprintf("%d:%d", userID, groupID) if resolver == nil {
if s.userGroupRateCache != nil { resolver = newUserGroupRateResolver(
if cached, ok := s.userGroupRateCache.Get(key); ok { s.userGroupRateRepo,
if multiplier, castOK := cached.(float64); castOK { s.userGroupRateCache,
userGroupRateCacheHitTotal.Add(1) resolveUserGroupRateCacheTTL(s.cfg),
return multiplier &s.userGroupRateSF,
"service.gateway",
)
} }
} return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
}
if s.userGroupRateRepo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if s.userGroupRateCache != nil {
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
} }
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数

View File

@@ -0,0 +1,336 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type openAIRecordUsageLogRepoStub struct {
UsageLogRepository
inserted bool
err error
calls int
lastLog *UsageLog
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
return s.inserted, s.err
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
return s.deductErr
}
type openAIRecordUsageSubRepoStub struct {
UserSubscriptionRepository
incrementCalls int
incrementErr error
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
return s.incrementErr
}
type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
return s.err
}
type openAIUserGroupRateRepoStub struct {
UserGroupRateRepository
rate *float64
err error
calls int
}
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
s.calls++
if s.err != nil {
return nil, s.err
}
return s.rate, nil
}
func i64p(v int64) *int64 {
return &v
}
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
t.Helper()
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
OutputTokens: usage.OutputTokens,
CacheCreationTokens: usage.CacheCreationInputTokens,
CacheReadTokens: usage.CacheReadInputTokens,
}, multiplier)
require.NoError(t, err)
return cost
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
groupID := int64(11)
groupRate := 1.4
userRate := 1.8
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_user_group_rate",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1001,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2001},
Account: &Account{ID: 3001},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
groupID := int64(12)
groupRate := 1.6
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_on_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1002,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2002},
Account: &Account{ID: 3002},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
groupID := int64(13)
groupRate := 1.25
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.userGroupRateResolver = nil
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_nil_resolver",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1003,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2003},
Account: &Account{ID: 3003},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1004},
User: &User{ID: 2004},
Account: &Account{ID: 3004},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_quota_update",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1005,
Quota: 100,
},
User: &User{ID: 2005},
Account: &Account{ID: 3005},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.Equal(t, 0, quotaSvc.rateLimitCalls)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_clamp_actual_input",
Usage: OpenAIUsage{
InputTokens: 2,
OutputTokens: 1,
CacheReadInputTokens: 5,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1006},
User: &User{ID: 2006},
Account: &Account{ID: 3006},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
}

View File

@@ -257,6 +257,7 @@ type OpenAIGatewayService struct {
billingService *BillingService billingService *BillingService
rateLimitService *RateLimitService rateLimitService *RateLimitService
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
userGroupRateResolver *userGroupRateResolver
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider openAITokenProvider *OpenAITokenProvider
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
usageLogRepo UsageLogRepository, usageLogRepo UsageLogRepository,
userRepo UserRepository, userRepo UserRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService, schedulerSnapshot *SchedulerSnapshotService,
@@ -308,6 +310,13 @@ func NewOpenAIGatewayService(
billingService: billingService, billingService: billingService,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
userGroupRateResolver: newUserGroupRateResolver(
userGroupRateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
openAITokenProvider: openAITokenProvider, openAITokenProvider: openAITokenProvider,
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
// Correct tool calls in final response // Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body) body = s.correctToolCallsInResponseBody(body)
} else { } else {
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
if terminalOK && terminalType == "response.failed" {
msg := extractOpenAISSEErrorMessage(terminalPayload)
if msg == "" {
msg = "Upstream compact response failed"
}
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText) usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel { if originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
return usage, nil return usage, nil
} }
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed":
return eventType, []byte(data), true
}
}
return "", nil, false
}
func extractOpenAISSEErrorMessage(payload []byte) string {
if len(payload) == 0 {
return ""
}
for _, path := range []string{"response.error.message", "error.message", "message"} {
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
return sanitizeUpstreamErrorMessage(msg)
}
}
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
}
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
if message == "" {
message = "Upstream returned an invalid non-streaming response"
}
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": message,
},
})
return fmt.Errorf("non-streaming openai protocol error: %s", message)
}
func extractCodexFinalResponse(body string) ([]byte, bool) { func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n") lines := strings.Split(body, "\n")
for _, line := range lines { for _, line := range lines {
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Get rate multiplier // Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
}
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)

View File

@@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
} }
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
svc := &OpenAIGatewayService{cfg: &config.Config{}}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
}
body := []byte(strings.Join([]string{
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`,
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.Nil(t, usage)
require.Error(t, err)
require.Equal(t, http.StatusBadGateway, rec.Code)
require.Contains(t, rec.Body.String(), "upstream rejected request")
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
}

View File

@@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
cfg, cfg,
nil, nil,
nil, nil,

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
gocache "github.com/patrickmn/go-cache"
"golang.org/x/sync/singleflight"
)
type userGroupRateResolver struct {
repo UserGroupRateRepository
cache *gocache.Cache
cacheTTL time.Duration
sf *singleflight.Group
logComponent string
}
func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver {
if cacheTTL <= 0 {
cacheTTL = defaultUserGroupRateCacheTTL
}
if cache == nil {
cache = gocache.New(cacheTTL, time.Minute)
}
if logComponent == "" {
logComponent = "service.gateway"
}
if sf == nil {
sf = &singleflight.Group{}
}
return &userGroupRateResolver{
repo: repo,
cache: cache,
cacheTTL: cacheTTL,
sf: sf,
logComponent: logComponent,
}
}
func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if r == nil || userID <= 0 || groupID <= 0 {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if r.cache != nil {
if cached, ok := r.cache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
}
if r.repo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := r.sf.Do(key, func() (any, error) {
if r.cache != nil {
if cached, ok := r.cache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if r.cache != nil {
r.cache.Set(key, multiplier, r.cacheTTL)
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
}

View File

@@ -0,0 +1,83 @@
package service
import (
"context"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/require"
)
type userGroupRateResolverRepoStub struct {
UserGroupRateRepository
rate *float64
err error
calls int
}
func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
s.calls++
if s.err != nil {
return nil, s.err
}
return s.rate, nil
}
func TestNewUserGroupRateResolver_Defaults(t *testing.T) {
resolver := newUserGroupRateResolver(nil, nil, 0, nil, "")
require.NotNil(t, resolver)
require.NotNil(t, resolver.cache)
require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL)
require.NotNil(t, resolver.sf)
require.Equal(t, "service.gateway", resolver.logComponent)
}
func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) {
var nilResolver *userGroupRateResolver
require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4))
resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test")
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4))
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4))
}
func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) {
resetGatewayHotpathStatsForTest()
rate := 1.7
repo := &userGroupRateResolverRepoStub{rate: &rate}
cache := gocache.New(time.Minute, time.Minute)
cache.Set("101:202", "bad-cache", time.Minute)
resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test")
got := resolver.Resolve(context.Background(), 101, 202, 1.2)
require.Equal(t, rate, got)
require.Equal(t, 1, repo.calls)
cached, ok := cache.Get("101:202")
require.True(t, ok)
require.Equal(t, rate, cached)
hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
require.Equal(t, int64(0), hit)
require.Equal(t, int64(1), miss)
require.Equal(t, int64(1), load)
require.Equal(t, int64(0), fallback)
}
func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) {
var nilSvc *GatewayService
require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3))
rate := 1.9
repo := &userGroupRateResolverRepoStub{rate: &rate}
resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway")
svc := &GatewayService{userGroupRateResolver: resolver}
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
require.Equal(t, rate, got)
require.Equal(t, 1, repo.calls)
}