fix(openai): 统一专属倍率计费链路并补齐回归测试
抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -245,23 +245,24 @@ type openAIWSRetryMetrics struct {
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
codexDetector CodexClientRestrictionDetector
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
codexDetector CodexClientRestrictionDetector
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -284,6 +285,7 @@ func NewOpenAIGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -296,18 +298,25 @@ func NewOpenAIGatewayService(
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway",
|
||||
),
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
@@ -3261,6 +3270,14 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
// Correct tool calls in final response
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
} else {
|
||||
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
|
||||
if terminalOK && terminalType == "response.failed" {
|
||||
msg := extractOpenAISSEErrorMessage(terminalPayload)
|
||||
if msg == "" {
|
||||
msg = "Upstream compact response failed"
|
||||
}
|
||||
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
|
||||
}
|
||||
usage = s.parseSSEUsageFromBody(bodyText)
|
||||
if originalModel != mappedModel {
|
||||
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||
@@ -3282,6 +3299,51 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok || data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return eventType, []byte(data), true
|
||||
}
|
||||
}
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
func extractOpenAISSEErrorMessage(payload []byte) string {
|
||||
if len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, path := range []string{"response.error.message", "error.message", "message"} {
|
||||
if msg := strings.TrimSpace(gjson.GetBytes(payload, path).String()); msg != "" {
|
||||
return sanitizeUpstreamErrorMessage(msg)
|
||||
}
|
||||
}
|
||||
return sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(payload)))
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.Response, c *gin.Context, message string) error {
|
||||
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
|
||||
if message == "" {
|
||||
message = "Upstream returned an invalid non-streaming response"
|
||||
}
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, message, "")
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("non-streaming openai protocol error: %s", message)
|
||||
}
|
||||
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
@@ -3413,7 +3475,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
resolver := s.userGroupRateResolver
|
||||
if resolver == nil {
|
||||
resolver = newUserGroupRateResolver(nil, nil, resolveUserGroupRateCacheTTL(s.cfg), nil, "service.openai_gateway")
|
||||
}
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
|
||||
Reference in New Issue
Block a user