refactor: Introduce pre-consume quota and unify relay handlers

This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
CaIon
2025-08-14 20:05:06 +08:00
parent 17bab355e4
commit e2037ad756
113 changed files with 3095 additions and 2518 deletions

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
logger.LogError(c, "invalid response or response body")
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string
@@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
common.SysError("error handling stream format: " + err.Error())
logger.SysError("error handling stream format: " + err.Error())
}
}
if len(data) > 0 {
@@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
@@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.LogError(c, "error processing tokens: "+err.Error())
logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
@@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
@@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
responseBody = geminiRespStr
}
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return &simpleResponse.Usage, nil
}
@@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
@@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
common.LogError(c, err.Error())
logger.LogError(c, err.Error())
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
@@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
@@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
@@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
@@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
localUsage = &dto.RealtimeUsage{}
// print now usage
}
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
@@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
@@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
@@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
}
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content