refactor: Introduce pre-consume quota and unify relay handlers

This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
CaIon
2025-08-14 20:05:06 +08:00
parent 17bab355e4
commit e2037ad756
113 changed files with 3095 additions and 2518 deletions

View File

@@ -2,17 +2,16 @@ package relay
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/gemini"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -20,64 +19,6 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if len(request.Contents) == 0 {
return nil, errors.New("contents is required")
}
return request, nil
}
// 流模式
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
if c.Query("alt") == "sse" {
relayInfo.IsStream = true
}
// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
// relayInfo.IsStream = true
// }
}
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
var inputTexts []string
for _, content := range textRequest.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
if len(inputTexts) == 0 {
return nil, nil
}
sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
return sensitiveWords, err
}
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
inputText := strings.Join(inputTexts, "\n")
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens
return inputTokens
}
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
@@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string {
return modelName
}
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
request, ok := info.Request.(*dto.GeminiChatRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo, req)
err := helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens := getGeminiInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
if isNoThinkingRequest(req) {
if isNoThinkingRequest(request) {
// check is thinking
if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
if !strings.Contains(info.OriginModelName, "-nothinking") {
// try to get no thinking model price
noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
noThinkingModelName := info.OriginModelName + "-nothinking"
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
if containPrice {
relayInfo.OriginModelName = noThinkingModelName
relayInfo.UpstreamModelName = noThinkingModelName
info.OriginModelName = noThinkingModelName
info.UpstreamModelName = noThinkingModelName
}
}
}
if req.GenerationConfig.ThinkingConfig == nil {
gemini.ThinkingAdaptor(req, relayInfo)
if request.GenerationConfig.ThinkingConfig == nil {
gemini.ThinkingAdaptor(request, info)
}
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
// Clean up empty system instruction
if req.SystemInstructions != nil {
if request.SystemInstructions != nil {
hasContent := false
for _, part := range req.SystemInstructions.Parts {
for _, part := range request.SystemInstructions.Parts {
if part.Text != "" {
hasContent = true
break
}
}
if !hasContent {
req.SystemInstructions = nil
request.SystemInstructions = nil
}
}
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
requestBody = bytes.NewReader(body)
} else {
// 使用 ConvertGeminiRequest 转换请求格式
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
if common.DebugEnabled {
println("Gemini request body: %s", string(jsonData))
}
logger.LogDebug(c, "Gemini request body: "+string(jsonData))
requestBody = bytes.NewReader(jsonData)
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoGemini(c)
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
relayInfo.IsGeminiBatchEmbedding = isBatch
info.IsGeminiBatchEmbedding = isBatch
var promptTokens int
var req any
var err error
var inputTexts []string
@@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
}
promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
relayInfo.SetPromptTokens(promptTokens)
c.Set("prompt_tokens", promptTokens)
err = helper.ModelMappedHelper(c, relayInfo, req)
err = helper.ModelMappedHelper(c, info, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
jsonData, err := common.Marshal(req)
@@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}