Merge pull request #815 from mt21625457/pr/openai-user-group-rate-upstream
fix(openai): 统一专属倍率计费链路并补齐回归测试
This commit is contained in:
@@ -164,7 +164,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@@ -521,6 +521,7 @@ type GatewayService struct {
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
@@ -582,6 +583,13 @@ func NewGatewayService(
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
svc.userGroupRateCache,
|
||||
userGroupRateTTL,
|
||||
&svc.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
|
||||
return svc
|
||||
@@ -6336,63 +6344,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
}
|
||||
|
||||
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||
if s == nil || userID <= 0 || groupID <= 0 {
|
||||
if s == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
resolver = newUserGroupRateResolver(
|
||||
s.userGroupRateRepo,
|
||||
s.userGroupRateCache,
|
||||
resolveUserGroupRateCacheTTL(s.cfg),
|
||||
&s.userGroupRateSF,
|
||||
"service.gateway",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.userGroupRateRepo == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
userGroupRateCacheMissTotal.Add(1)
|
||||
|
||||
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userGroupRateCacheLoadTotal.Add(1)
|
||||
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
||||
if repoErr != nil {
|
||||
return nil, repoErr
|
||||
}
|
||||
multiplier := groupDefaultMultiplier
|
||||
if userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
if s.userGroupRateCache != nil {
|
||||
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
||||
}
|
||||
return multiplier, nil
|
||||
})
|
||||
if shared {
|
||||
userGroupRateCacheSFSharedTotal.Add(1)
|
||||
}
|
||||
if err != nil {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
multiplier, ok := value.(float64)
|
||||
if !ok {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
return multiplier
|
||||
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
|
||||
336
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
336
backend/internal/service/openai_gateway_record_usage_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIRecordUsageLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.calls++
|
||||
s.lastLog = log
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
deductCalls int
|
||||
deductErr error
|
||||
lastAmount float64
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
s.deductCalls++
|
||||
s.lastAmount = amount
|
||||
return s.deductErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageSubRepoStub struct {
|
||||
UserSubscriptionRepository
|
||||
|
||||
incrementCalls int
|
||||
incrementErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
s.incrementCalls++
|
||||
return s.incrementErr
|
||||
}
|
||||
|
||||
// openAIRecordUsageAPIKeyQuotaStub is a test double for the API key quota
// service passed to RecordUsage through OpenAIRecordUsageInput.APIKeyService.
type openAIRecordUsageAPIKeyQuotaStub struct {
	// quotaCalls and rateLimitCalls count calls to UpdateQuotaUsed and
	// UpdateRateLimitUsage respectively.
	quotaCalls     int
	rateLimitCalls int
	// err is returned by both methods.
	err error
	// lastAmount is the cost passed to whichever method was called last.
	lastAmount float64
}

// UpdateQuotaUsed records the cost and the call, then returns the configured error.
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
	s.lastAmount = cost
	s.quotaCalls++
	return s.err
}

// UpdateRateLimitUsage records the cost and the call, then returns the configured error.
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
	s.lastAmount = cost
	s.rateLimitCalls++
	return s.err
}
|
||||
|
||||
type openAIUserGroupRateRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
rate *float64
|
||||
err error
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
s.calls++
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.rate, nil
|
||||
}
|
||||
|
||||
// i64p returns a pointer to a fresh copy of v, for filling optional int64 fields.
func i64p(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
|
||||
|
||||
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
|
||||
return &OpenAIGatewayService{
|
||||
usageLogRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: subRepo,
|
||||
cfg: cfg,
|
||||
billingService: NewBillingService(cfg, nil),
|
||||
billingCacheService: &BillingCacheService{},
|
||||
deferredService: &DeferredService{},
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||
t.Helper()
|
||||
|
||||
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
|
||||
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheCreationTokens: usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: usage.CacheReadInputTokens,
|
||||
}, multiplier)
|
||||
require.NoError(t, err)
|
||||
return cost
|
||||
}
|
||||
|
||||
// max returns the larger of a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
|
||||
groupID := int64(11)
|
||||
groupRate := 1.4
|
||||
userRate := 1.8
|
||||
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_user_group_rate",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1001,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2001},
|
||||
Account: &Account{ID: 3001},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, rateRepo.calls)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
|
||||
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
|
||||
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
|
||||
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
|
||||
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
|
||||
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
|
||||
groupID := int64(12)
|
||||
groupRate := 1.6
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_group_default_on_error",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1002,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2002},
|
||||
Account: &Account{ID: 3002},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, rateRepo.calls)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
|
||||
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
|
||||
groupID := int64(13)
|
||||
groupRate := 1.25
|
||||
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
|
||||
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc.userGroupRateResolver = nil
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_group_default_nil_resolver",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1003,
|
||||
GroupID: i64p(groupID),
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
RateMultiplier: groupRate,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 2003},
|
||||
Account: &Account{ID: 3003},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_duplicate",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1004},
|
||||
User: &User{ID: 2004},
|
||||
Account: &Account{ID: 3004},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_quota_update",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1005,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 2005},
|
||||
Account: &Account{ID: 3005},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.Equal(t, 0, quotaSvc.rateLimitCalls)
|
||||
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
|
||||
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_clamp_actual_input",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 2,
|
||||
OutputTokens: 1,
|
||||
CacheReadInputTokens: 5,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1006},
|
||||
User: &User{ID: 2006},
|
||||
Account: &Account{ID: 3006},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
|
||||
}
|
||||
@@ -257,6 +257,7 @@ type OpenAIGatewayService struct {
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -308,6 +310,13 @@ func NewOpenAIGatewayService(
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway",
|
||||
),
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
// Correct tool calls in final response
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
} else {
|
||||
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
|
||||
if terminalOK && terminalType == "response.failed" {
|
||||
msg := extractOpenAISSEErrorMessage(terminalPayload)
|
||||
if msg == "" {
|
||||
msg = "Upstream compact response failed"
|
||||
}
|
||||
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
|
||||
}
|
||||
usage = s.parseSSEUsageFromBody(bodyText)
|
||||
if originalModel != mappedModel {
|
||||
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return eventType, []byte(data), true
|
||||
}
|
||||
}
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
func extractOpenAISSEErrorMessage(payload []byte) string {
|
||||
if len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, path := range []string{"response.error.message", "error.message", "message"} {
|
||||
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
|
||||
return sanitizeUpstreamErrorMessage(msg)
|
||||
}
|
||||
}
|
||||
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
|
||||
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
|
||||
if message == "" {
|
||||
message = "Upstream returned an invalid non-streaming response"
|
||||
}
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("non-streaming openai protocol error: %s", message)
|
||||
}
|
||||
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
|
||||
}
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
|
||||
@@ -1576,3 +1576,27 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
|
||||
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
||||
}
|
||||
|
||||
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
}
|
||||
body := []byte(strings.Join([]string{
|
||||
`data: {"type":"response.failed","error":{"message":"upstream rejected request"}}`,
|
||||
`data: [DONE]`,
|
||||
}, "\n"))
|
||||
|
||||
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||
require.Nil(t, usage)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "upstream rejected request")
|
||||
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
|
||||
}
|
||||
|
||||
@@ -391,6 +391,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
|
||||
103
backend/internal/service/user_group_rate_resolver.go
Normal file
103
backend/internal/service/user_group_rate_resolver.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type userGroupRateResolver struct {
|
||||
repo UserGroupRateRepository
|
||||
cache *gocache.Cache
|
||||
cacheTTL time.Duration
|
||||
sf *singleflight.Group
|
||||
logComponent string
|
||||
}
|
||||
|
||||
func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver {
|
||||
if cacheTTL <= 0 {
|
||||
cacheTTL = defaultUserGroupRateCacheTTL
|
||||
}
|
||||
if cache == nil {
|
||||
cache = gocache.New(cacheTTL, time.Minute)
|
||||
}
|
||||
if logComponent == "" {
|
||||
logComponent = "service.gateway"
|
||||
}
|
||||
if sf == nil {
|
||||
sf = &singleflight.Group{}
|
||||
}
|
||||
|
||||
return &userGroupRateResolver{
|
||||
repo: repo,
|
||||
cache: cache,
|
||||
cacheTTL: cacheTTL,
|
||||
sf: sf,
|
||||
logComponent: logComponent,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||
if r == nil || userID <= 0 || groupID <= 0 {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||
if r.cache != nil {
|
||||
if cached, ok := r.cache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.repo == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
userGroupRateCacheMissTotal.Add(1)
|
||||
|
||||
value, err, shared := r.sf.Do(key, func() (any, error) {
|
||||
if r.cache != nil {
|
||||
if cached, ok := r.cache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userGroupRateCacheLoadTotal.Add(1)
|
||||
userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID)
|
||||
if repoErr != nil {
|
||||
return nil, repoErr
|
||||
}
|
||||
|
||||
multiplier := groupDefaultMultiplier
|
||||
if userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
if r.cache != nil {
|
||||
r.cache.Set(key, multiplier, r.cacheTTL)
|
||||
}
|
||||
return multiplier, nil
|
||||
})
|
||||
if shared {
|
||||
userGroupRateCacheSFSharedTotal.Add(1)
|
||||
}
|
||||
if err != nil {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
multiplier, ok := value.(float64)
|
||||
if !ok {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
return multiplier
|
||||
}
|
||||
83
backend/internal/service/user_group_rate_resolver_test.go
Normal file
83
backend/internal/service/user_group_rate_resolver_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userGroupRateResolverRepoStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
rate *float64
|
||||
err error
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *userGroupRateResolverRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
s.calls++
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.rate, nil
|
||||
}
|
||||
|
||||
func TestNewUserGroupRateResolver_Defaults(t *testing.T) {
|
||||
resolver := newUserGroupRateResolver(nil, nil, 0, nil, "")
|
||||
|
||||
require.NotNil(t, resolver)
|
||||
require.NotNil(t, resolver.cache)
|
||||
require.Equal(t, defaultUserGroupRateCacheTTL, resolver.cacheTTL)
|
||||
require.NotNil(t, resolver.sf)
|
||||
require.Equal(t, "service.gateway", resolver.logComponent)
|
||||
}
|
||||
|
||||
func TestUserGroupRateResolverResolve_FallbackForNilResolverAndInvalidIDs(t *testing.T) {
|
||||
var nilResolver *userGroupRateResolver
|
||||
require.Equal(t, 1.4, nilResolver.Resolve(context.Background(), 101, 202, 1.4))
|
||||
|
||||
resolver := newUserGroupRateResolver(nil, nil, time.Second, nil, "service.test")
|
||||
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 0, 202, 1.4))
|
||||
require.Equal(t, 1.4, resolver.Resolve(context.Background(), 101, 0, 1.4))
|
||||
}
|
||||
|
||||
func TestUserGroupRateResolverResolve_InvalidCacheEntryLoadsRepoAndCaches(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
rate := 1.7
|
||||
repo := &userGroupRateResolverRepoStub{rate: &rate}
|
||||
cache := gocache.New(time.Minute, time.Minute)
|
||||
cache.Set("101:202", "bad-cache", time.Minute)
|
||||
resolver := newUserGroupRateResolver(repo, cache, time.Minute, nil, "service.test")
|
||||
|
||||
got := resolver.Resolve(context.Background(), 101, 202, 1.2)
|
||||
require.Equal(t, rate, got)
|
||||
require.Equal(t, 1, repo.calls)
|
||||
|
||||
cached, ok := cache.Get("101:202")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, rate, cached)
|
||||
|
||||
hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
|
||||
require.Equal(t, int64(0), hit)
|
||||
require.Equal(t, int64(1), miss)
|
||||
require.Equal(t, int64(1), load)
|
||||
require.Equal(t, int64(0), fallback)
|
||||
}
|
||||
|
||||
func TestGatewayServiceGetUserGroupRateMultiplier_FallbacksAndUsesExistingResolver(t *testing.T) {
|
||||
var nilSvc *GatewayService
|
||||
require.Equal(t, 1.3, nilSvc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.3))
|
||||
|
||||
rate := 1.9
|
||||
repo := &userGroupRateResolverRepoStub{rate: &rate}
|
||||
resolver := newUserGroupRateResolver(repo, nil, time.Minute, nil, "service.gateway")
|
||||
svc := &GatewayService{userGroupRateResolver: resolver}
|
||||
|
||||
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||
require.Equal(t, rate, got)
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
Reference in New Issue
Block a user