From 1f5ef24ecdaf1cb4bd437bab628a67e04322d66e Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 30 Jul 2025 22:35:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=98=BE=E5=BC=8F=E6=8C=87=E5=AE=9A=20?= =?UTF-8?q?error=20=E8=B7=B3=E8=BF=87=E9=87=8D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/playground.go | 10 +++---- controller/relay.go | 8 +++--- middleware/distributor.go | 2 +- model/channel.go | 2 +- relay/audio_handler.go | 10 +++---- relay/claude_handler.go | 18 ++++++------- relay/embedding_handler.go | 15 +++++------ relay/gemini_handler.go | 16 +++++------ relay/image_handler.go | 20 +++++++------- relay/relay-text.go | 32 +++++++++++----------- relay/rerank_handler.go | 20 +++++++------- relay/responses_handler.go | 20 +++++++------- relay/websocket.go | 6 ++--- service/channel.go | 2 +- types/error.go | 54 ++++++++++++++++++++++++++++---------- 15 files changed, 129 insertions(+), 106 deletions(-) diff --git a/controller/playground.go b/controller/playground.go index 0073cf06..64c0e1ce 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -28,19 +28,19 @@ func Playground(c *gin.Context) { useAccessToken := c.GetBool("use_access_token") if useAccessToken { - newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } playgroundRequest := &dto.PlayGroundRequest{} err := common.UnmarshalBodyReusable(c, playgroundRequest) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } if playgroundRequest.Model == "" { - newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } c.Set("original_model", playgroundRequest.Model) @@ -51,7 +51,7 @@ func Playground(c *gin.Context) { group = userGroup } else { if !setting.GroupInUserUsableGroups(group) && group != userGroup { - newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } c.Set("group", group) @@ -62,7 +62,7 @@ func Playground(c *gin.Context) { // Write user context to ensure acceptUnsetRatio is available userCache, err := model.GetUserCache(userId) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeQueryDataError) + newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return } userCache.WriteContext(c) diff --git a/controller/relay.go b/controller/relay.go index 01081d3d..e7318e9b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -127,7 +127,7 @@ func WssRelay(c *gin.Context) { defer ws.Close() if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError()) + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) return } @@ -258,10 +258,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -277,7 +277,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if types.IsChannelError(openaiErr) { return true } - if types.IsLocalError(openaiErr) { + if types.IsSkipRetryError(openaiErr) { return false } if retryTimes <= 0 { diff --git a/middleware/distributor.go b/middleware/distributor.go index cba9b521..fb4a6645 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -247,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) diff --git a/model/channel.go b/model/channel.go index 6277fcda..ea263c84 100644 --- a/model/channel.go +++ b/model/channel.go @@ -138,7 +138,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { - return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed) + return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index f39dbd82..88777838 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if err != nil { common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } promptTokens := 0 @@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } resp, err := adaptor.DoRequest(c, relayInfo, ioReader) diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 2c60a91e..b4bf78ff 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateClaudeRequest(c) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.Stream { @@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index be11bb2b..fef8d2c9 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err := common.UnmarshalBodyReusable(c, &embeddingRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getEmbeddingPromptToken(*embeddingRequest) @@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 2e2c1480..43c7ca58 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -109,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateGeminiRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoGemini(c) @@ -121,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } // model mapped 模型映射 err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota @@ -175,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -198,13 +198,13 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewReader(body) } else { jsonData, err := common.Marshal(req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -216,7 +216,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/image_handler.go b/relay/image_handler.go index c97eb48e..f0b69699 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -115,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } var preConsumedQuota int var quota int @@ -173,16 +173,16 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota) + return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) } } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -191,20 +191,20 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -216,7 +216,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/relay-text.go b/relay/relay-text.go index 1856a2a1..97313be6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -90,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.WebSearchOptions != nil { @@ -103,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } // 获取 promptTokens,如果上下文中已经存在,则直接使用 @@ -121,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { promptTokens, err = getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } c.Set("prompt_tokens", promptTokens) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -165,7 +164,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader @@ -173,7 +172,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } if common.DebugEnabled { println("requestBody: ", string(body)) @@ -182,7 +181,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.ChannelSetting.SystemPrompt != "" { @@ -207,7 +206,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -219,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } @@ -231,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) - if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -304,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError) + return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } relayInfo.UserQuota = userQuota if userQuota > 100*preConsumedQuota { @@ -334,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError) + return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } } return preConsumedQuota, userQuota, nil diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 0190cf08..1e547e2a 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -31,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError err := common.UnmarshalBodyReusable(c, &rerankRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getRerankPromptToken(*rerankRequest) @@ -53,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -68,7 +68,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -76,17 +76,17 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -98,7 +98,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 52d1db6e..65c240b2 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoResponses(c, req) @@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewError(err, types.ErrorCodeReadRequestBodyFailed) + return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override if len(relayInfo.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } for key, value := range relayInfo.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/websocket.go b/relay/websocket.go index 659e27d5..3715b237 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr err := helper.ModelMappedHelper(c, relayInfo, nil) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) //var requestBody io.Reader diff --git a/service/channel.go b/service/channel.go index 4d38e6ed..faac6d10 100644 --- a/service/channel.go +++ b/service/channel.go @@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if types.IsChannelError(err) { return true } - if types.IsLocalError(err) { + if types.IsSkipRetryError(err) { return false } if err.StatusCode == http.StatusUnauthorized { diff --git a/types/error.go b/types/error.go index 2a8105c7..74c3bae5 100644 --- a/types/error.go +++ b/types/error.go @@ -78,6 +78,7 @@ const ( type NewAPIError struct { Err error RelayError any + skipRetry bool errorType ErrorType errorCode ErrorCode StatusCode int @@ -170,33 +171,39 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { return result } -func NewError(err error, errorCode ErrorCode) *NewAPIError { - return &NewAPIError{ +type NewAPIErrorOptions func(*NewAPIError) + +func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: nil, errorType: ErrorTypeNewAPIError, StatusCode: http.StatusInternalServerError, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + return e } -func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError { +func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Message: err.Error(), Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError { +func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError { - return &NewAPIError{ +func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: OpenAIError{ Message: err.Error(), @@ -206,9 +213,14 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New StatusCode: statusCode, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + + return e } -func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { +func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { code, ok := openAIError.Code.(string) if !ok { code = fmt.Sprintf("%v", openAIError.Code) @@ -216,26 +228,34 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { if openAIError.Type == "" { openAIError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: openAIError, errorType: ErrorTypeOpenAIError, StatusCode: statusCode, Err: errors.New(openAIError.Message), errorCode: ErrorCode(code), } + for _, op := range ops { + op(e) + } + return e } -func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { +func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { if claudeError.Type == "" { claudeError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: claudeError, errorType: ErrorTypeClaudeError, StatusCode: statusCode, Err: errors.New(claudeError.Message), errorCode: ErrorCode(claudeError.Type), } + for _, op := range ops { + op(e) + } + return e } func IsChannelError(err *NewAPIError) bool { @@ -245,10 +265,16 @@ func IsChannelError(err *NewAPIError) bool { return strings.HasPrefix(string(err.errorCode), "channel:") } -func IsLocalError(err *NewAPIError) bool { +func IsSkipRetryError(err *NewAPIError) bool { if err == nil { return false } - return err.errorType == ErrorTypeNewAPIError + return err.skipRetry +} + +func ErrOptionWithSkipRetry() NewAPIErrorOptions { + return func(e *NewAPIError) { + e.skipRetry = true + } }