merge: sync upstream changes

This commit is contained in:
cyhhao
2026-01-17 18:30:16 +08:00
189 changed files with 20383 additions and 2096 deletions

View File

@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
}
}
@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData
}
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
// correctToolCallsInResponseBody 修正响应体中的工具调用
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
if len(body) == 0 {
return body
}
bodyStr := string(body)
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
if changed {
return []byte(corrected)
}
return body
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
@@ -1470,28 +1515,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加 UserAgent