Merge branch 'alpha' into feature/claude-code

# Conflicts:
#	web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
Seefs
2025-07-31 21:19:43 +08:00
52 changed files with 1400 additions and 312 deletions

View File

@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
promptTokens := 0
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)

View File

@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {

View File

@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {

View File

@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{

View File

@@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
helper.SetEventStreamHeaders(c)
// 处理流式请求的 ping 保活
generalSettings := operation_setting.GetGeneralSetting()
if generalSettings.PingIntervalEnabled {
if generalSettings.PingIntervalEnabled && !info.DisablePing {
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
stopPinger = startPingKeepAlive(c, pingInterval)
// 使用defer确保在任何情况下都能停止ping goroutine

View File

@@ -1,12 +1,10 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
@@ -175,6 +173,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if info.IsStream {
info.DisablePing = true
return GeminiTextGenerationStreamHandler(c, info, resp)
} else {
return GeminiTextGenerationHandler(c, info, resp)
@@ -212,60 +211,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -20,7 +20,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if common.DebugEnabled {
@@ -31,7 +31,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 计算使用量(基于 UsageMetadata
@@ -54,7 +54,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := common.Marshal(geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)

View File

@@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
@@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
var geminiResponse GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
@@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}

View File

@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// Check if the response indicates an error

View File

@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
defer common.CloseResponseBodyGracefully(resp)
@@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
@@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 写入新的 response body

View File

@@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if responsesResponse.Error != nil {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)

View File

@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -15,13 +15,13 @@ import (
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,

View File

@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if tencentSb.Response.Error.Code != 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -67,10 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
} else {
a.RequestMode = RequestModeGemini
}
}
@@ -83,6 +83,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -100,6 +101,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "generateContent"
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
@@ -231,6 +237,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return gemini.GeminiImageHandler(c, info, resp)
}
usage, err = gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:

View File

@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if !zhipuResponse.Success {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if textRequest.Stream {
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed)
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
@@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}

View File

@@ -88,6 +88,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn

View File

@@ -16,7 +16,7 @@ import (
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var xinRerankResponse xinference.XinRerankResponse
err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else {
err = common.Unmarshal(responseBody, &jinaResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}

View File

@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getEmbeddingPromptToken(*embeddingRequest)
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@@ -81,7 +80,7 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0
}
return false
}
@@ -110,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoGemini(c)
@@ -122,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -160,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
@@ -176,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
@@ -199,13 +198,13 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewReader(body)
} else {
jsonData, err := json.Marshal(req)
jsonData, err := common.Marshal(req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -217,7 +216,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
@@ -230,7 +229,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewError(err, types.ErrorCodeDoRequestFailed)
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval

View File

@@ -115,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
var preConsumedQuota int
var quota int
@@ -173,16 +173,16 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return types.NewError(err, types.ErrorCodeQueryDataError)
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota-quota < 0 {
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
}
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
@@ -191,20 +191,20 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -216,7 +216,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}

View File

@@ -90,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if textRequest.WebSearchOptions != nil {
@@ -103,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
// 获取 promptTokens如果上下文中已经存在则直接使用
@@ -121,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed)
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -165,7 +164,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
@@ -173,7 +172,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
if common.DebugEnabled {
println("requestBody: ", string(body))
@@ -182,7 +181,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if relayInfo.ChannelSetting.SystemPrompt != "" {
@@ -207,7 +206,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -219,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
@@ -231,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -304,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota <= 0 {
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
@@ -334,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 {
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry())
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
}
return preConsumedQuota, userQuota, nil
@@ -517,6 +515,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
if !ratio.IsZero() && quota == 0 {
quota = 1
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -31,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if len(rerankRequest.Documents) == 0 {
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getRerankPromptToken(*rerankRequest)
@@ -53,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -68,7 +68,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
@@ -76,17 +76,17 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -98,7 +98,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}

View File

@@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateResponsesRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
@@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
}

View File

@@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
//var requestBody io.Reader