fix(openai): 统一专属倍率计费链路并补齐回归测试

抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-03-06 14:54:52 +08:00
parent 63a8c76946
commit a18bbb5f2f
8 changed files with 692 additions and 112 deletions

View File

@@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct {
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
codexDetector CodexClientRestrictionDetector
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
accountRepo AccountRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
codexDetector CodexClientRestrictionDetector
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
userGroupRateResolver *userGroupRateResolver
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
@@ -296,18 +298,25 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
userGroupRateResolver: newUserGroupRateResolver(
userGroupRateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
if terminalOK && terminalType == "response.failed" {
msg := extractOpenAISSEErrorMessage(terminalPayload)
if msg == "" {
msg = "Upstream compact response failed"
}
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
return usage, nil
}
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed":
return eventType, []byte(data), true
}
}
return "", nil, false
}
func extractOpenAISSEErrorMessage(payload []byte) string {
if len(payload) == 0 {
return ""
}
for _, path := range []string{"response.error.message", "error.message", "message"} {
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
return sanitizeUpstreamErrorMessage(msg)
}
}
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
}
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
if message == "" {
message = "Upstream returned an invalid non-streaming response"
}
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": message,
},
})
return fmt.Errorf("non-streaming openai protocol error: %s", message)
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
}
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)