refactor: Introduce pre-consume quota and unify relay handlers

This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
CaIon
2025-08-14 20:05:06 +08:00
parent 17bab355e4
commit e2037ad756
113 changed files with 3095 additions and 2518 deletions

View File

@@ -4,107 +4,40 @@ import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
audioRequest := &dto.AudioRequest{}
err := common.UnmarshalBodyReusable(c, audioRequest)
if err != nil {
return nil, err
}
switch info.RelayMode {
case relayconstant.RelayModeAudioSpeech:
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
default:
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
}
return audioRequest, nil
}
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
audioRequest, ok := info.Request.(*dto.AudioRequest)
if !ok {
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
err := helper.ModelMappedHelper(c, info, audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
resp, err := adaptor.DoRequest(c, info, ioReader)
if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
@@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
@@ -43,7 +43,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
logger.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
@@ -53,7 +53,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
logger.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
@@ -109,7 +109,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
@@ -134,14 +134,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}

View File

@@ -4,9 +4,9 @@ import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)

View File

@@ -7,7 +7,9 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/helper"
"one-api/service"
"strings"
"one-api/types"
@@ -46,7 +48,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
@@ -148,7 +150,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -161,7 +163,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -171,7 +173,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false
}
})
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return nil, &usage
}
@@ -181,7 +183,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
common2 "one-api/common"
"one-api/logger"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
@@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
err := helper.PingData(c)
if err != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
logger.LogError(c, "SSE ping error: "+err.Error())
done <- err
return
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -118,7 +119,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -129,11 +130,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
common.SysError("error sending stream response: " + err.Error())
logger.SysError("error sending stream response: " + err.Error())
}
return true
})
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -143,7 +144,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
@@ -168,7 +169,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -609,7 +610,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
@@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
logger.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return nil
@@ -653,7 +654,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error")
logger.SysError("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
@@ -667,7 +668,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
logger.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
@@ -736,12 +737,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
}
common.IOCopyBytesGracefully(c, nil, responseData)
service.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),

View File

@@ -5,8 +5,8 @@ import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
@@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
logger.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
@@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -118,7 +119,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
@@ -153,7 +154,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -175,7 +176,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
@@ -216,7 +217,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -49,7 +50,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse
@@ -154,7 +155,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -171,14 +172,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -203,11 +204,11 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
}
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -36,14 +37,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
common.SysError("failed to decode base64: " + err.Error())
logger.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
common.SysError("failed to create temp file: " + err.Error())
logger.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
@@ -51,7 +52,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error())
logger.SysError("failed to write to temp file: " + err.Error())
return nil
}
@@ -61,7 +62,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Add user field
if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error())
logger.SysError("failed to add user field: " + err.Error())
return nil
}
@@ -74,13 +75,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
common.SysError("failed to create form file: " + err.Error())
logger.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error())
logger.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
@@ -88,7 +89,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
common.SysError("failed to create request: " + err.Error())
logger.SysError("failed to create request: " + err.Error())
return nil
}
@@ -99,7 +100,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
logger.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
@@ -109,7 +110,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error())
logger.SysError("failed to decode response: " + err.Error())
return nil
}
@@ -219,7 +220,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
@@ -239,7 +240,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
}
return true
})
@@ -258,7 +259,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)

View File

@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
},
},
Parameters: dto.GeminiImageParameters{
SampleCount: request.N,
SampleCount: int(request.N),
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -17,7 +18,7 @@ import (
)
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
@@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil
}
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
}
}
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return usage, nil
}
@@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
@@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
// 直接发送 GeminiChatResponse 响应
err = helper.StringData(c, data)
if err != nil {
common.LogError(c, err.Error())
logger.LogError(c, err.Error())
}
info.SendResponseCount++
return true

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
@@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
finishReason = constant.FinishReasonToolCalls
err = handleStream(c, info, emptyResponse)
if err != nil {
common.LogError(c, err.Error())
logger.LogError(c, err.Error())
}
response.ClearToolCalls()
@@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
err = handleStream(c, info, response)
if err != nil {
common.LogError(c, err.Error())
logger.LogError(c, err.Error())
}
if isStop {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
@@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
logger.SysError("send final response failed: " + err.Error())
}
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c)
@@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
@@ -1057,13 +1058,13 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
break
}
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil
}
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
@@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -5,9 +5,9 @@ import (
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {

View File

@@ -12,7 +12,7 @@ import (
"io"
"net/http"
"net/url"
"one-api/common"
"one-api/logger"
"sort"
"strings"
"time"
@@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error {
if err != nil {
return err
}
common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
payloadHash := sha256.Sum256(body)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
c.Set(HexPayloadHashKey, hexPayloadHash)

View File

@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
@@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}

View File

@@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
service.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
@@ -50,7 +51,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err
}
@@ -63,7 +64,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.LogError(c, "failed to marshal gemini response: "+err.Error())
logger.LogError(c, "failed to marshal gemini response: "+err.Error())
return err
}
@@ -110,14 +111,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
logger.SysError("error processing stream response: " + err.Error())
}
}
return nil
@@ -146,7 +147,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -213,7 +214,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -227,7 +228,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
case relaycommon.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -245,7 +246,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.SysError("error marshalling gemini response: " + err.Error())
logger.SysError("error marshalling gemini response: " + err.Error())
return
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
logger.LogError(c, "invalid response or response body")
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string
@@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
common.SysError("error handling stream format: " + err.Error())
logger.SysError("error handling stream format: " + err.Error())
}
}
if len(data) > 0 {
@@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
@@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.LogError(c, "error processing tokens: "+err.Error())
logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
@@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
@@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
responseBody = geminiRespStr
}
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return &simpleResponse.Usage, nil
}
@@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
@@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
common.LogError(c, err.Error())
logger.LogError(c, err.Error())
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
@@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
@@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
@@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
@@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
localUsage = &dto.RealtimeUsage{}
// print now usage
}
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
@@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
@@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
@@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
}
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -16,7 +17,7 @@ import (
)
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
// read response body
var responsesResponse dto.OpenAIResponsesResponse
@@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
// compute usage
usage := dto.Usage{}
@@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
logger.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}

View File

@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -58,15 +59,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -78,7 +79,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -96,7 +97,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false
}
})
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return nil, responseText
}
@@ -105,7 +106,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
@@ -133,6 +134,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}

View File

@@ -4,9 +4,9 @@ import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
@@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
common.SysError(fmt.Sprintf("Get Task error: %v", err))
logger.SysError(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()

View File

@@ -13,6 +13,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -106,7 +107,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
@@ -117,17 +118,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
err = helper.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
logger.SysError("error reading stream: " + err.Error())
}
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
}
@@ -138,7 +139,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -156,7 +157,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -47,7 +48,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
@@ -63,7 +64,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
}
return true
})
@@ -74,12 +75,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return usage, nil
}
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -101,7 +102,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, encodeJson)
service.IOCopyBytesGracefully(c, resp, encodeJson)
return xaiResponse.Usage, nil
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/helper"
"one-api/types"
"strings"
@@ -143,7 +144,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -218,20 +219,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
logger.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
logger.SysError("error closing websocket connection: " + err.Error())
}
break
}
@@ -282,6 +283,6 @@ func getAPIVersion(c *gin.Context, modelName string) string {
return apiVersion
}
apiVersion = "v1.1"
common.SysLog("api_version not found, using default: " + apiVersion)
logger.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}

View File

@@ -8,8 +8,10 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"sync"
@@ -38,7 +40,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
common.SysError("invalid zhipu key: " + apikey)
logger.SysError("invalid zhipu key: " + apikey)
return ""
}
@@ -186,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -195,13 +197,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
@@ -212,7 +214,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
return false
}
})
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
return usage, nil
}
@@ -222,7 +224,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -214,7 +215,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
logger.SysError("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
@@ -300,7 +301,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
common.IOCopyBytesGracefully(c, nil, respBody)
service.IOCopyBytesGracefully(c, nil, respBody)
return nil
}
@@ -521,7 +522,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
logger.SysError("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
@@ -572,7 +573,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//无实例账号自动禁用渠道No available account instance
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
common.SysError("get_channel_null: " + err.Error())
logger.SysError("get_channel_null: " + err.Error())
}
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
@@ -18,68 +17,26 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
if err != nil {
return nil, err
}
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
if textRequest.Model == "" {
return nil, errors.New("field model is required")
}
return textRequest, nil
}
func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
relayInfo := relaycommon.GenRelayInfoClaude(c)
textRequest, ok := info.Request.(*dto.ClaudeRequest)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
}
if textRequest.Stream {
relayInfo.IsStream = true
}
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
err := helper.ModelMappedHelper(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
@@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
textRequest.Temperature = common.GetPointer[float64](1.0)
}
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
relayInfo.UpstreamModelName = textRequest.Model
info.UpstreamModelName = textRequest.Model
}
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
//log.Printf("usage: %v", usage)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
switch info.RelayMode {
default:
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
}
info.PromptTokens = promptTokens
return promptTokens, err
}

View File

@@ -1,10 +1,12 @@
package common
import (
"errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/types"
"strings"
"time"
@@ -33,17 +35,6 @@ type ClaudeConvertInfo struct {
Done bool
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
RelayFormatOpenAIResponses = "openai_responses"
RelayFormatOpenAIAudio = "openai_audio"
RelayFormatOpenAIImage = "openai_image"
RelayFormatRerank = "rerank"
RelayFormatEmbedding = "embedding"
)
type RerankerInfo struct {
Documents []any
ReturnDocuments bool
@@ -59,61 +50,103 @@ type ResponsesUsageInfo struct {
BuiltInTools map[string]*BuildInToolInfo
}
type RelayInfo struct {
type ChannelMeta struct {
ChannelType int
ChannelId int
ChannelIsMultiKey bool // 是否多密钥
ChannelMultiKeyIndex int // 多密钥索引
TokenId int
TokenKey string
UserId int
UsingGroup string // 使用的分组
UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
ChannelIsMultiKey bool
ChannelMultiKeyIndex int
ChannelBaseUrl string
ApiType int
ApiVersion string
ApiKey string
Organization string
ChannelCreateTime int64
ParamOverride map[string]interface{}
ChannelSetting dto.ChannelSettings
ChannelOtherSettings dto.ChannelOtherSettings
UpstreamModelName string
IsModelMapped bool
}
type RelayInfo struct {
TokenId int
TokenKey string
UserId int
UsingGroup string // 使用的分组
UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
//SendLastReasoningResponse bool
ApiType int
IsStream bool
IsGeminiBatchEmbedding bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
//RecodeModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
AudioUsage bool
ReasoningEffort string
ChannelSetting dto.ChannelSettings
ChannelOtherSettings dto.ChannelOtherSettings
ParamOverride map[string]interface{}
UserSetting dto.UserSetting
UserEmail string
UserQuota int
RelayFormat string
SendResponseCount int
ChannelCreateTime int64
RequestURLPath string
PromptTokens int
SupportStreamOptions bool
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
AudioUsage bool
ReasoningEffort string
UserSetting dto.UserSetting
UserEmail string
UserQuota int
RelayFormat types.RelayFormat
SendResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
PriceData types.PriceData
Request dto.Request
ThinkingContentInfo
*ClaudeConvertInfo
*RerankerInfo
*ResponsesUsageInfo
*ChannelMeta
}
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
apiType, _ := common.ChannelType2APIType(channelType)
channelMeta := &ChannelMeta{
ChannelType: channelType,
ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride,
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
IsModelMapped: false,
}
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
if ok {
channelMeta.ChannelSetting = channelSetting
}
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
if ok {
channelMeta.ChannelOtherSettings = channelOtherSettings
}
info.ChannelMeta = channelMeta
}
// 定义支持流式选项的通道类型
@@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info := genBaseRelayInfo(c, nil)
info.RelayFormat = types.RelayFormatOpenAIRealtime
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
@@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
return info
}
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
@@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info
}
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeRerank
info.RelayFormat = RelayFormatRerank
info.RelayFormat = types.RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
Documents: request.Documents,
ReturnDocuments: request.GetReturnDocuments(),
}
return info
}
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIAudio
func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatOpenAIAudio
return info
}
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatEmbedding
func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatEmbedding
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeResponses
info.RelayFormat = RelayFormatOpenAIResponses
info.RelayFormat = types.RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
if len(req.Tools) > 0 {
for _, tool := range req.Tools {
if len(request.Tools) > 0 {
for _, tool := range request.Tools {
toolType := common.Interface2String(tool["type"])
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
ToolName: toolType,
@@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
}
}
}
info.IsStream = req.Stream
return info
}
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatGemini
func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatGemini
info.ShouldIncludeUsage = false
return info
}
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIImage
func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatOpenAIImage
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayFormat = types.RelayFormatOpenAI
return info
}
func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
//channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
}
// firstResponseTime = time.Now() - 1 second
apiType, _ := common.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
TokenKey: tokenKey,
UserId: userId,
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
TokenUnlimited: tokenUnlimited,
Request: request,
UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
RequestURLPath: c.Request.URL.String(),
IsStream: request.IsStream(c),
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" {
info.BaseUrl = constant.ChannelBaseURLs[channelType]
}
if info.ChannelType == constant.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == constant.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
if ok {
info.ChannelSetting = channelSetting
}
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
if ok {
info.ChannelOtherSettings = channelOtherSettings
}
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
if ok {
@@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
return info
}
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
info.PromptTokens = promptTokens
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
switch relayFormat {
case types.RelayFormatOpenAI:
return GenRelayInfoOpenAI(c, request), nil
case types.RelayFormatOpenAIAudio:
return GenRelayInfoOpenAIAudio(c, request), nil
case types.RelayFormatOpenAIImage:
return GenRelayInfoImage(c, request), nil
case types.RelayFormatOpenAIRealtime:
return GenRelayInfoWs(c, ws), nil
case types.RelayFormatClaude:
return GenRelayInfoClaude(c, request), nil
case types.RelayFormatRerank:
if request, ok := request.(*dto.RerankRequest); ok {
return GenRelayInfoRerank(c, request), nil
}
return nil, errors.New("request is not a RerankRequest")
case types.RelayFormatGemini:
return GenRelayInfoGemini(c, request), nil
case types.RelayFormatEmbedding:
return GenRelayInfoEmbedding(c, request), nil
case types.RelayFormatOpenAIResponses:
if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
return GenRelayInfoResponses(c, request), nil
}
return nil, errors.New("request is not a OpenAIResponsesRequest")
default:
return nil, errors.New("invalid relay format")
}
}
func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
info.PromptTokens = promptTokens
}
func (info *RelayInfo) SetFirstResponseTime() {

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}

View File

@@ -8,7 +8,6 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
@@ -16,69 +15,27 @@ import (
"github.com/gin-gonic/gin"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token
}
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
if embeddingRequest.Input == nil {
return fmt.Errorf("input is empty")
}
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
embeddingRequest.Model = "omni-moderation-latest"
}
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
embeddingRequest.Model = c.Param("model")
}
return nil
}
info.InitChannelMeta(c)
func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
}
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
err := helper.ModelMappedHelper(c, info, embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}

View File

@@ -2,17 +2,16 @@ package relay
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/gemini"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -20,64 +19,6 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if len(request.Contents) == 0 {
return nil, errors.New("contents is required")
}
return request, nil
}
// 流模式
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
if c.Query("alt") == "sse" {
relayInfo.IsStream = true
}
// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
// relayInfo.IsStream = true
// }
}
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
var inputTexts []string
for _, content := range textRequest.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
if len(inputTexts) == 0 {
return nil, nil
}
sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
return sensitiveWords, err
}
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
inputText := strings.Join(inputTexts, "\n")
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens
return inputTokens
}
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
@@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string {
return modelName
}
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
request, ok := info.Request.(*dto.GeminiChatRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo, req)
err := helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens := getGeminiInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
if isNoThinkingRequest(req) {
if isNoThinkingRequest(request) {
// check is thinking
if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
if !strings.Contains(info.OriginModelName, "-nothinking") {
// try to get no thinking model price
noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
noThinkingModelName := info.OriginModelName + "-nothinking"
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
if containPrice {
relayInfo.OriginModelName = noThinkingModelName
relayInfo.UpstreamModelName = noThinkingModelName
info.OriginModelName = noThinkingModelName
info.UpstreamModelName = noThinkingModelName
}
}
}
if req.GenerationConfig.ThinkingConfig == nil {
gemini.ThinkingAdaptor(req, relayInfo)
if request.GenerationConfig.ThinkingConfig == nil {
gemini.ThinkingAdaptor(request, info)
}
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
// Clean up empty system instruction
if req.SystemInstructions != nil {
if request.SystemInstructions != nil {
hasContent := false
for _, part := range req.SystemInstructions.Parts {
for _, part := range request.SystemInstructions.Parts {
if part.Text != "" {
hasContent = true
break
}
}
if !hasContent {
req.SystemInstructions = nil
request.SystemInstructions = nil
}
}
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
requestBody = bytes.NewReader(body)
} else {
// 使用 ConvertGeminiRequest 转换请求格式
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
if common.DebugEnabled {
println("Gemini request body: %s", string(jsonData))
}
logger.LogDebug(c, "Gemini request body: "+string(jsonData))
requestBody = bytes.NewReader(jsonData)
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}
func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoGemini(c)
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
relayInfo.IsGeminiBatchEmbedding = isBatch
info.IsGeminiBatchEmbedding = isBatch
var promptTokens int
var req any
var err error
var inputTexts []string
@@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
}
promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
relayInfo.SetPromptTokens(promptTokens)
c.Set("prompt_tokens", promptTokens)
err = helper.ModelMappedHelper(c, relayInfo, req)
err = helper.ModelMappedHelper(c, info, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
jsonData, err := common.Marshal(req)
@@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewReader(jsonData)
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
logger.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -100,7 +101,7 @@ func Done(c *gin.Context) {
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
common.LogError(c, "websocket connection is nil")
logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
@@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
common.LogError(c, "websocket connection is nil")
logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))

View File

@@ -4,9 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
common2 "one-api/common"
"one-api/dto"
common2 "one-api/logger"
"one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
}
if request != nil {
switch info.RelayFormat {
case common.RelayFormatGemini:
case types.RelayFormatGemini:
// Gemini 模型映射
case common.RelayFormatClaude:
case types.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIResponses:
case types.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIAudio:
case types.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIImage:
case types.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case common.RelayFormatRerank:
case types.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case common.RelayFormatEmbedding:
case types.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}

View File

@@ -5,35 +5,14 @@ import (
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting/ratio_setting"
"one-api/types"
"github.com/gin-gonic/gin"
)
type GroupRatioInfo struct {
GroupRatio float64
GroupSpecialRatio float64
HasSpecialRatio bool
}
type PriceData struct {
ModelPrice float64
ModelRatio float64
CompletionRatio float64
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
groupRatioInfo := GroupRatioInfo{
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
groupRatioInfo := types.GroupRatioInfo{
GroupRatio: 1.0, // default ratio
GroupSpecialRatio: -1,
}
@@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR
return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
groupRatioInfo := HandleGroupRatio(c, info)
@@ -75,8 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheCreationRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
if meta.MaxTokens != 0 {
preConsumedTokens = promptTokens + meta.MaxTokens
}
var success bool
var matchName string
@@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
acceptUnsetRatio = true
}
if !acceptUnsetRatio {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
}
}
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
@@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
if meta.ImagePriceRatio != 0 {
modelPrice = modelPrice * meta.ImagePriceRatio
}
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
priceData := types.PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
@@ -115,38 +97,32 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
info.PriceData = priceData
return priceData, nil
}
type PerCallPriceData struct {
ModelPrice float64
Quota int
GroupRatioInfo GroupRatioInfo
}
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
groupRatioInfo := HandleGroupRatio(c, info)
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
priceData := PerCallPriceData{
ModelPrice: modelPrice,
Quota: quota,
GroupRatioInfo: groupRatioInfo,
}
return priceData
}
//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
// groupRatioInfo := HandleGroupRatio(c, info)
//
// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
// // 如果没有配置价格,则使用默认价格
// if !success {
// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
// if !ok {
// modelPrice = 0.1
// } else {
// modelPrice = defaultPrice
// }
// }
// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
// priceData := types.PerCallPriceData{
// ModelPrice: modelPrice,
// Quota: quota,
// GroupRatioInfo: groupRatioInfo,
// }
// return priceData
//}
func ContainPriceOrRatio(modelName string) bool {
_, ok := ratio_setting.GetModelPrice(modelName, false)

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
@@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case <-done:
case <-time.After(5 * time.Second):
common.LogError(c, "timeout waiting for goroutines to exit")
logger.LogError(c, "timeout waiting for goroutines to exit")
}
close(stopChan)
@@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer func() {
wg.Done()
if r := recover(); r != nil {
common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
common.SafeSendBool(stopChan, true)
}
if common.DebugEnabled {
@@ -136,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case err := <-done:
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
logger.LogError(c, "ping data error: "+err.Error())
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-time.After(10 * time.Second):
common.LogError(c, "ping data send timeout")
logger.LogError(c, "ping data send timeout")
return
case <-ctx.Done():
return
@@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
// 监听客户端断开连接
return
case <-pingTimeout.C:
common.LogError(c, "ping goroutine max duration reached")
logger.LogError(c, "ping goroutine max duration reached")
return
}
}
@@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer func() {
wg.Done()
if r := recover(); r != nil {
common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
}
common.SafeSendBool(stopChan, true)
if common.DebugEnabled {
@@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
}
case <-time.After(10 * time.Second):
common.LogError(c, "data handler timeout")
logger.LogError(c, "data handler timeout")
return
case <-ctx.Done():
return
@@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
logger.LogError(c, "scanner error: "+err.Error())
}
}
})
@@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
logger.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
logger.LogInfo(c, "streaming finished")
case <-c.Request.Context().Done():
// 客户端断开连接
common.LogInfo(c, "client disconnected")
logger.LogInfo(c, "client disconnected")
}
}

View File

@@ -0,0 +1,301 @@
package helper
import (
"errors"
"fmt"
"math"
"one-api/common"
"one-api/dto"
"one-api/logger"
relayconstant "one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
switch format {
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
request, err = GetAndValidateGeminiRequest(c)
case types.RelayFormatClaude:
request, err = GetAndValidateClaudeRequest(c)
case types.RelayFormatOpenAIResponses:
request, err = GetAndValidateResponsesRequest(c)
case types.RelayFormatOpenAIImage:
request, err = GetAndValidOpenAIImageRequest(c, relayMode)
case types.RelayFormatEmbedding:
request, err = GetAndValidateEmbeddingRequest(c, relayMode)
case types.RelayFormatRerank:
request, err = GetAndValidateRerankRequest(c)
case types.RelayFormatOpenAIAudio:
request, err = GetAndValidAudioRequest(c, relayMode)
case types.RelayFormatOpenAIRealtime:
// nothing to do, no request body
default:
return nil, fmt.Errorf("unsupported relay format: %s", format)
}
return request, err
}
func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
audioRequest := &dto.AudioRequest{}
err := common.UnmarshalBodyReusable(c, audioRequest)
if err != nil {
return nil, err
}
switch relayMode {
case relayconstant.RelayModeAudioSpeech:
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
default:
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
}
return audioRequest, nil
}
func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if rerankRequest.Query == "" {
return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if len(rerankRequest.Documents) == 0 {
return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
return rerankRequest, nil
}
func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if embeddingRequest.Input == nil {
return nil, fmt.Errorf("input is empty")
}
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
embeddingRequest.Model = "omni-moderation-latest"
}
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
embeddingRequest.Model = c.Param("model")
}
return embeddingRequest, nil
}
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if request.Model == "" {
return nil, errors.New("model is required")
}
if request.Input == nil {
return nil, errors.New("input is required")
}
return request, nil
}
func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
imageRequest := &dto.ImageRequest{}
switch relayMode {
case relayconstant.RelayModeImagesEdits:
_, err := c.MultipartForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
imageRequest.Prompt = formData.Get("prompt")
imageRequest.Model = formData.Get("model")
imageRequest.N = uint(common.String2Int(formData.Get("n")))
imageRequest.Quality = formData.Get("quality")
imageRequest.Size = formData.Get("size")
if imageRequest.Model == "gpt-image-1" {
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
watermark := formData.Has("watermark")
if watermark {
imageRequest.Watermark = &watermark
}
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-3"
}
if strings.Contains(imageRequest.Size, "×") {
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
} else if imageRequest.Model == "gpt-image-1" {
if imageRequest.Quality == "" {
imageRequest.Quality = "auto"
}
}
if imageRequest.Prompt == "" {
return nil, errors.New("prompt is required")
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
}
return imageRequest, nil
}
func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
if err != nil {
return nil, err
}
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
if textRequest.Model == "" {
return nil, errors.New("field model is required")
}
//if textRequest.Stream {
// relayInfo.IsStream = true
//}
return textRequest, nil
}
func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
textRequest := &dto.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
}
if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
if textRequest.MaxTokens > math.MaxInt32/2 {
return nil, errors.New("max_tokens is invalid")
}
if textRequest.Model == "" {
return nil, errors.New("model is required")
}
if textRequest.WebSearchOptions != nil {
if textRequest.WebSearchOptions.SearchContextSize != "" {
validSizes := map[string]bool{
"high": true,
"medium": true,
"low": true,
}
if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
}
} else {
textRequest.WebSearchOptions.SearchContextSize = "medium"
}
}
switch relayMode {
case relayconstant.RelayModeCompletions:
if textRequest.Prompt == "" {
return nil, errors.New("field prompt is required")
}
case relayconstant.RelayModeChatCompletions:
if len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations:
if textRequest.Input == nil || textRequest.Input == "" {
return nil, errors.New("field input is required")
}
case relayconstant.RelayModeEdits:
if textRequest.Instruction == "" {
return nil, errors.New("field instruction is required")
}
}
return textRequest, nil
}
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if len(request.Contents) == 0 {
return nil, errors.New("contents is required")
}
//if c.Query("alt") == "sse" {
// relayInfo.IsStream = true
//}
return request, nil
}

View File

@@ -3,19 +3,15 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -23,183 +19,41 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
imageRequest := &dto.ImageRequest{}
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeImagesEdits:
_, err := c.MultipartForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
imageRequest.Prompt = formData.Get("prompt")
imageRequest.Model = formData.Get("model")
imageRequest.N = common.String2Int(formData.Get("n"))
imageRequest.Quality = formData.Get("quality")
imageRequest.Size = formData.Get("size")
info.InitChannelMeta(c)
if imageRequest.Model == "gpt-image-1" {
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
imageRequest, ok := info.Request.(*dto.ImageRequest)
if info.ApiType == constant.APITypeVolcEngine {
watermark := formData.Has("watermark")
imageRequest.Watermark = &watermark
}
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-3"
}
if strings.Contains(imageRequest.Size, "×") {
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
} else if imageRequest.Model == "gpt-image-1" {
if imageRequest.Quality == "" {
imageRequest.Quality = "auto"
}
}
if imageRequest.Prompt == "" {
return nil, errors.New("prompt is required")
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
}
if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
return imageRequest, nil
}
func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
err := helper.ModelMappedHelper(c, info, imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
var preConsumedQuota int
var quota int
var userQuota int
if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
} else {
sizeRatio := 1.0
qualityRatio := 1.0
if strings.HasPrefix(imageRequest.Model, "dall-e") {
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 0.4
} else if imageRequest.Size == "512x512" {
sizeRatio = 0.45
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2
}
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
qualityRatio = 1.5
}
}
}
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota-quota < 0 {
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
}
}
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
if info.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
@@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
@@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
@@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
if usage.(*dto.Usage).TotalTokens == 0 {
usage.(*dto.Usage).TotalTokens = imageRequest.N
usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
}
if usage.(*dto.Usage).PromptTokens == 0 {
usage.(*dto.Usage).PromptTokens = imageRequest.N
usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
}
quality := "standard"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
var logContent string
if len(imageRequest.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
}
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
return nil
}

View File

@@ -2,172 +2,56 @@ package relay
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/setting/operation_setting"
"one-api/types"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
textRequest := &dto.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
}
if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
if textRequest.MaxTokens > math.MaxInt32/2 {
return nil, errors.New("max_tokens is invalid")
}
if textRequest.Model == "" {
return nil, errors.New("model is required")
}
if textRequest.WebSearchOptions != nil {
if textRequest.WebSearchOptions.SearchContextSize != "" {
validSizes := map[string]bool{
"high": true,
"medium": true,
"low": true,
}
if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
}
} else {
textRequest.WebSearchOptions.SearchContextSize = "medium"
}
}
switch relayInfo.RelayMode {
case relayconstant.RelayModeCompletions:
if textRequest.Prompt == "" {
return nil, errors.New("field prompt is required")
}
case relayconstant.RelayModeChatCompletions:
if len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations:
if textRequest.Input == nil || textRequest.Input == "" {
return nil, errors.New("field input is required")
}
case relayconstant.RelayModeEdits:
if textRequest.Instruction == "" {
return nil, errors.New("field instruction is required")
}
}
relayInfo.IsStream = textRequest.Stream
return textRequest, nil
}
info.InitChannelMeta(c)
func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
relayInfo := relaycommon.GenRelayInfo(c)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
if !ok {
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
}
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
if setting.ShouldCheckPromptSensitive() {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
err := helper.ModelMappedHelper(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens = value.(int)
relayInfo.PromptTokens = promptTokens
} else {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newApiErr != nil {
return newApiErr
}
defer func() {
if newApiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
includeUsage := true
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil {
includeUsage = textRequest.StreamOptions.IncludeUsage
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
textRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
relayInfo.ShouldIncludeUsage = includeUsage
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -177,12 +61,12 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if relayInfo.ChannelSetting.SystemPrompt != "" {
if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
request := convertedRequest.(*dto.GeneralOpenAIRequest)
containSystemPrompt := false
@@ -196,22 +80,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// 如果没有系统提示,则添加系统提示
systemMessage := dto.Message{
Role: request.GetSystemRoleName(),
Content: relayInfo.ChannelSetting.SystemPrompt,
Content: info.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
} else if relayInfo.ChannelSetting.SystemPromptOverride {
} else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
// 如果有系统提示,且允许覆盖,则拼接到前面
for i, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
if message.IsStringContent() {
request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
} else {
contents := message.ParseContent()
contents = append([]dto.MediaContent{
{
Type: dto.ContentTypeText,
Text: relayInfo.ChannelSetting.SystemPrompt,
Text: info.ChannelSetting.SystemPrompt,
},
}, contents...)
request.Messages[i].Content = contents
@@ -228,10 +112,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -240,14 +124,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
requestBody = bytes.NewBuffer(jsonData)
}
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -256,125 +139,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
newApiErr = service.RelayErrorHandler(httpResp, false)
newApiErr := service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
}
usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
if newApiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
}
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
}
return nil
}
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings:
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
}
info.PromptTokens = promptTokens
return promptTokens, err
}
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
var err error
var words []string
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
words, err = service.CheckSensitiveMessages(textRequest.Messages)
case relayconstant.RelayModeCompletions:
words, err = service.CheckSensitiveInput(textRequest.Prompt)
case relayconstant.RelayModeModerations:
words, err = service.CheckSensitiveInput(textRequest.Input)
case relayconstant.RelayModeEmbeddings:
words, err = service.CheckSensitiveInput(textRequest.Input)
}
return words, err
}
// 预扣费并返回用户剩余配额
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota <= 0 {
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
// 非无限令牌,判断令牌额度是否充足
tokenQuota := c.GetInt("token_quota")
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
}
return preConsumedQuota, userQuota, nil
}
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
})
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -392,12 +181,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
completionRatio := relayInfo.PriceData.CompletionRatio
cacheRatio := relayInfo.PriceData.CacheRatio
imageRatio := relayInfo.PriceData.ImageRatio
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -470,7 +259,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
var audioInputQuota decimal.Decimal
var audioInputPrice float64
if !priceData.UsePrice {
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
var cachedTokensWithRatio decimal.Decimal
@@ -518,7 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
totalTokens := promptTokens + completionTokens
var logContent string
if !priceData.UsePrice {
if !relayInfo.PriceData.UsePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -530,8 +319,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
if !ratio.IsZero() && quota == 0 {
quota = 1
@@ -540,11 +329,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
@@ -560,7 +349,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
@@ -604,7 +393,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -127,7 +128,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
logger.SysError("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")

View File

@@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token
}
func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
rerankRequest, ok := info.Request.(*dto.RerankRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if len(rerankRequest.Documents) == 0 {
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
err := helper.ModelMappedHelper(c, info, rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
@@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
requestBody = bytes.NewBuffer(jsonData)
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil
}

View File

@@ -3,7 +3,6 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -12,7 +11,6 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -20,82 +18,24 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if request.Model == "" {
return nil, errors.New("model is required")
}
if len(request.Input) == 0 {
return nil, errors.New("input is required")
}
return request, nil
func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
}
func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
return sensitiveWords, err
}
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
inputTokens := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens
return inputTokens
}
func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateResponsesRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
request, ok := info.Request.(*dto.OpenAIResponsesRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request))
}
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, req)
err := helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens := getInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
@@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
if len(info.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
for key, value := range relayInfo.ParamOverride {
for key, value := range info.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
@@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
postConsumeQuota(c, info, usage.(*dto.Usage), "")
}
return nil
}

View File

@@ -15,13 +15,6 @@ import (
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
// get & validate textRequest 获取并验证文本请求
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
//if err != nil {
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
//}
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())