diff --git a/common/custom-event.go b/common/custom-event.go index d8f9ec9f..256db546 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync" ) type stringWriter interface { @@ -52,6 +53,8 @@ type CustomEvent struct { Id string Retry uint Data interface{} + + Mutex sync.Mutex } func encode(writer io.Writer, event CustomEvent) error { @@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error { } func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + r.Mutex.Lock() + defer r.Mutex.Unlock() header := w.Header() header["Content-Type"] = contentType diff --git a/common/str.go b/common/str.go index 88b58c72..f5399eab 100644 --- a/common/str.go +++ b/common/str.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "encoding/json" "math/rand" + "net/url" + "regexp" "strconv" + "strings" "unsafe" ) @@ -95,3 +98,95 @@ func GetJsonString(data any) string { b, _ := json.Marshal(data) return string(b) } + +// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// Example: +// http://example.com -> http://***.com +// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** +// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** +// 192.168.1.1 -> ***.***.***.*** +func MaskSensitiveInfo(str string) string { + // Mask URLs + urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) + str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + return urlStr + } + + host := u.Host + if host == "" { + return urlStr + } + + // Split host by dots + parts := strings.Split(host, ".") + if len(parts) < 2 { + // If less than 2 parts, just mask the whole host + return u.Scheme + "://***" + u.Path + } + + // Keep the TLD (Top Level Domain) and mask the rest + var maskedHost string + if len(parts) == 2 { + // example.com -> ***.com + maskedHost = "***." + parts[len(parts)-1] + } else { + // Handle cases like sub.domain.co.uk or api.example.com + // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + maskedHost = "***." + secondLastPart + "." + lastPart + } else { + // Regular TLD like .com, .org + maskedHost = "***." + lastPart + } + } + + result := u.Scheme + "://" + maskedHost + + // Mask path + if u.Path != "" && u.Path != "/" { + pathParts := strings.Split(strings.Trim(u.Path, "/"), "/") + maskedPathParts := make([]string, len(pathParts)) + for i := range pathParts { + if pathParts[i] != "" { + maskedPathParts[i] = "***" + } + } + if len(maskedPathParts) > 0 { + result += "/" + strings.Join(maskedPathParts, "/") + } + } else if u.Path == "/" { + result += "/" + } + + // Mask query parameters + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + // If can't parse query, just mask the whole query string + result += "?***" + } else { + maskedParams := make([]string, 0, len(values)) + for key := range values { + maskedParams = append(maskedParams, key+"=***") + } + if len(maskedParams) > 0 { + result += "?" + strings.Join(maskedParams, "&") + } + } + } + + return result + }) + + // Mask IP addresses + ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) + str = ipPattern.ReplaceAllString(str, "***.***.***.***") + + return str +} diff --git a/controller/channel-test.go b/controller/channel-test.go index c1c3c21d..75fec463 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -209,7 +209,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError), } } var httpResp *http.Response @@ -220,7 +220,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeBadResponse), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError), } } } @@ -236,7 +236,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: errors.New("usage is nil"), - newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody), + newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), } } usage := usageA.(*dto.Usage) @@ -246,7 +246,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } info.PromptTokens = usage.PromptTokens @@ -417,7 +417,7 @@ func testAllChannels(notify bool) error { if common.AutomaticDisableChannelEnabled && !shouldBanChannel { if milliseconds > disableThreshold { err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded) + newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) shouldBanChannel = true } } diff --git a/controller/playground.go b/controller/playground.go index 0073cf06..64c0e1ce 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -28,19 +28,19 @@ func Playground(c *gin.Context) { useAccessToken := c.GetBool("use_access_token") if useAccessToken { - newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } playgroundRequest := &dto.PlayGroundRequest{} err := common.UnmarshalBodyReusable(c, playgroundRequest) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } if playgroundRequest.Model == "" { - newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } c.Set("original_model", playgroundRequest.Model) @@ -51,7 +51,7 @@ func Playground(c *gin.Context) { group = userGroup } else { if !setting.GroupInUserUsableGroups(group) && group != userGroup { - newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } c.Set("group", group) @@ -62,7 +62,7 @@ func Playground(c *gin.Context) { // Write user context to ensure acceptUnsetRatio is available userCache, err := model.GetUserCache(userId) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeQueryDataError) + newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return } userCache.WriteContext(c) diff --git a/controller/relay.go b/controller/relay.go index d4b5fd18..e7318e9b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,8 +62,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { other["channel_id"] = channelId other["channel_name"] = c.GetString("channel_name") other["channel_type"] = c.GetInt("channel_type") - - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other) + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) } return err @@ -128,7 +127,7 @@ func WssRelay(c *gin.Context) { defer ws.Close() if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError()) + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) return } @@ -259,10 +258,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -278,7 +277,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if types.IsChannelError(openaiErr) { return true } - if types.IsLocalError(openaiErr) { + if types.IsSkipRetryError(openaiErr) { return false } if retryTimes <= 0 { diff --git a/middleware/distributor.go b/middleware/distributor.go index cba9b521..fb4a6645 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -247,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) diff --git a/model/channel.go b/model/channel.go index 6277fcda..58f0a064 100644 --- a/model/channel.go +++ b/model/channel.go @@ -75,7 +75,7 @@ func (channel *Channel) getKeys() []string { // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { res := make([]string, len(arr)) for i, v := range arr { res[i] = string(v) @@ -138,7 +138,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { - return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed) + return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { @@ -197,7 +197,7 @@ func (channel *Channel) GetGroups() []string { func (channel *Channel) GetOtherInfo() map[string]interface{} { otherInfo := make(map[string]interface{}) if channel.OtherInfo != "" { - err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo) + err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { common.SysError("failed to unmarshal other info: " + err.Error()) } @@ -425,7 +425,7 @@ func (channel *Channel) Update() error { trimmed := strings.TrimSpace(keyStr) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { keys = make([]string, len(arr)) for i, v := range arr { keys[i] = string(v) @@ -571,10 +571,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channelCache.Status == status { return false } - // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 - if status != common.ChannelStatusEnabled { - return false - } CacheUpdateChannelStatus(channelId, status) } } @@ -778,7 +774,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str func (channel *Channel) ValidateSettings() error { channelParams := &dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), channelParams) + err := common.Unmarshal([]byte(*channel.Setting), channelParams) if err != nil { return err } @@ -789,7 +785,7 @@ func (channel *Channel) ValidateSettings() error { func (channel *Channel) GetSetting() dto.ChannelSettings { setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), &setting) + err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 @@ -800,7 +796,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { } func (channel *Channel) SetSetting(setting dto.ChannelSettings) { - settingBytes, err := json.Marshal(setting) + settingBytes, err := common.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) return @@ -811,7 +807,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) { func (channel *Channel) GetParamOverride() map[string]interface{} { paramOverride := make(map[string]interface{}) if channel.ParamOverride != nil && *channel.ParamOverride != "" { - err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) + err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { common.SysError("failed to unmarshal param override: " + err.Error()) } diff --git a/relay/audio_handler.go b/relay/audio_handler.go index f39dbd82..88777838 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if err != nil { common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } promptTokens := 0 @@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } resp, err := adaptor.DoRequest(c, relayInfo, ioReader) diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 0d430c62..754f29c8 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela var aliTaskResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 59cb0a11..4f448e01 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest { func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 6d90fa71..fcf63854 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro var fullTextResponse dto.FlexibleEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) @@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { return types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ff7c63fa..3ccd2d78 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http helper.SetEventStreamHeaders(c) // 处理流式请求的 ping 保活 generalSettings := operation_setting.GetGeneralSetting() - if generalSettings.PingIntervalEnabled { + if generalSettings.PingIntervalEnabled && !info.DisablePing { pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second stopPinger = startPingKeepAlive(c, pingInterval) // 使用defer确保在任何情况下都能停止ping goroutine diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 2e31ec55..2b7b7e39 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,12 +1,10 @@ package gemini import ( - "encoding/json" "errors" "fmt" "io" "net/http" - "one-api/common" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" @@ -175,6 +173,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { if info.IsStream { + info.DisablePing = true return GeminiTextGenerationStreamHandler(c, info, resp) } else { return GeminiTextGenerationHandler(c, info, resp) @@ -212,60 +211,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) } -func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - responseBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) - } - _ = resp.Body.Close() - - var geminiResponse GeminiImageResponse - if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - if len(geminiResponse.Predictions) == 0 { - return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody) - } - - // convert to openai format response - openAIResponse := dto.ImageResponse{ - Created: common.GetTimestamp(), - Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), - } - - for _, prediction := range geminiResponse.Predictions { - if prediction.RaiFilteredReason != "" { - continue // skip filtered image - } - openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ - B64Json: prediction.BytesBase64Encoded, - }) - } - - jsonResponse, jsonErr := json.Marshal(openAIResponse) - if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - - // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb - // each image has fixed 258 tokens - const imageTokens = 258 - generatedImages := len(openAIResponse.Data) - - usage := &dto.Usage{ - PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens - CompletionTokens: 0, // image generation does not calculate completion tokens - TotalTokens: imageTokens * generatedImages, - } - - return usage, nil -} - func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 0870e3fa..7d459cc2 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -20,7 +20,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if common.DebugEnabled { @@ -31,7 +31,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re var geminiResponse GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 计算使用量(基于 UsageMetadata) @@ -54,7 +54,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re // 直接返回 Gemini 原生格式的 JSON 响应 jsonResponse, err := common.Marshal(geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, jsonResponse) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 7e57bdac..5dac0ce5 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var geminiResponse GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } var geminiResponse GeminiEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // convert to openai format response @@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h jsonResponse, jsonErr := common.Marshal(openAIResponse) if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } + +func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse GeminiImageResponse + if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) + if jsonErr != nil { + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage := &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + + return usage, nil +} diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 3c6a1d99..28af1866 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R var jimengResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // Check if the response indicates an error diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 82bd2d26..2252b407 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { common.LogError(c, "invalid response or response body") - return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } defer common.CloseResponseBodyGracefully(resp) @@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) @@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) @@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } var usageResp dto.SimpleResponse err = common.Unmarshal(responseBody, &usageResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 写入新的 response body diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index d9dd96b9..fd57924b 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &responsesResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if responsesResponse.Error != nil { return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 4db31573..cbd60f5e 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index fabaf9c6..2e37ad15 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -15,13 +15,13 @@ import ( func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } usage := &dto.Usage{ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index c3d96c49..78ce6238 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp var tencentSb TencentChatResponseSB responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if tencentSb.Response.Error.Code != 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index fa895de0..c88b4359 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -67,10 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude") { a.RequestMode = RequestModeClaude - } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { - a.RequestMode = RequestModeGemini } else if strings.Contains(info.UpstreamModelName, "llama") { a.RequestMode = RequestModeLlama + } else { + a.RequestMode = RequestModeGemini } } @@ -83,6 +83,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -100,6 +101,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } else { suffix = "generateContent" } + + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + suffix = "predict" + } + if region == "global" { return fmt.Sprintf( "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", @@ -231,6 +237,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.RelayMode == constant.RelayModeGemini { usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) } else { + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return gemini.GeminiImageHandler(c, info, resp) + } usage, err = gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 916a200d..35882ed5 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if !zhipuResponse.Success { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 2c60a91e..b4bf78ff 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateClaudeRequest(c) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.Stream { @@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 45fde019..27827d97 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -88,6 +88,7 @@ type RelayInfo struct { BaseUrl string SupportStreamOptions bool ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping IsModelMapped bool ClientWs *websocket.Conn TargetWs *websocket.Conn diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index ce823b3a..57df5fe3 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -16,7 +16,7 @@ import ( func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var xinRerankResponse xinference.XinRerankResponse err = common.Unmarshal(responseBody, &xinRerankResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) for i, result := range xinRerankResponse.Results { @@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { err = common.Unmarshal(responseBody, &jinaResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index be11bb2b..fef8d2c9 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err := common.UnmarshalBodyReusable(c, &embeddingRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getEmbeddingPromptToken(*embeddingRequest) @@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 0f1aa5bf..43c7ca58 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -81,7 +80,7 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { - return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0 + return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0 } return false } @@ -110,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateGeminiRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoGemini(c) @@ -122,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } // model mapped 模型映射 err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -160,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota @@ -176,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -199,13 +198,13 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewReader(body) } else { - jsonData, err := json.Marshal(req) + jsonData, err := common.Marshal(req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -217,7 +216,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } @@ -230,7 +229,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return types.NewError(err, types.ErrorCodeDoRequestFailed) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index c72aea6a..df8c5072 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon ) generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled + pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second if pingInterval <= 0 { pingInterval = DefaultPingInterval diff --git a/relay/image_handler.go b/relay/image_handler.go index c97eb48e..f0b69699 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -115,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } var preConsumedQuota int var quota int @@ -173,16 +173,16 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota) + return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) } } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -191,20 +191,20 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -216,7 +216,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/relay-text.go b/relay/relay-text.go index 84d4e38b..97313be6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -90,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.WebSearchOptions != nil { @@ -103,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } // 获取 promptTokens,如果上下文中已经存在,则直接使用 @@ -121,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { promptTokens, err = getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } c.Set("prompt_tokens", promptTokens) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -165,7 +164,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader @@ -173,7 +172,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } if common.DebugEnabled { println("requestBody: ", string(body)) @@ -182,7 +181,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.ChannelSetting.SystemPrompt != "" { @@ -207,7 +206,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -219,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } @@ -231,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) - if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -304,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError) + return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } relayInfo.UserQuota = userQuota if userQuota > 100*preConsumedQuota { @@ -334,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError) + return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } } return preConsumedQuota, userQuota, nil @@ -517,6 +515,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { + if !ratio.IsZero() && quota == 0 { + quota = 1 + } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 0190cf08..1e547e2a 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -31,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError err := common.UnmarshalBodyReusable(c, &rerankRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getRerankPromptToken(*rerankRequest) @@ -53,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -68,7 +68,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -76,17 +76,17 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -98,7 +98,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 52d1db6e..65c240b2 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoResponses(c, req) @@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewError(err, types.ErrorCodeReadRequestBodyFailed) + return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override if len(relayInfo.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } for key, value := range relayInfo.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/websocket.go b/relay/websocket.go index 659e27d5..3715b237 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr err := helper.ModelMappedHelper(c, relayInfo, nil) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) //var requestBody io.Reader diff --git a/service/channel.go b/service/channel.go index 4d38e6ed..faac6d10 100644 --- a/service/channel.go +++ b/service/channel.go @@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if types.IsChannelError(err) { return true } - if types.IsLocalError(err) { + if types.IsSkipRetryError(err) { return false } if err.StatusCode == http.StatusUnauthorized { diff --git a/service/error.go b/service/error.go index 83979add..94d9c250 100644 --- a/service/error.go +++ b/service/error.go @@ -1,7 +1,6 @@ package service import ( - "encoding/json" "errors" "fmt" "io" @@ -112,7 +111,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) return } statusCodeMapping := make(map[string]string) - err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) + err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) if err != nil { return } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 53b53f88..d550b2c3 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -3,6 +3,7 @@ package setting import ( "encoding/json" "fmt" + "math" "one-api/common" "sync" ) @@ -58,6 +59,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } + if limits[0] > math.MaxInt32 || limits[1] > math.MaxInt32 { + return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1]) + } } return nil diff --git a/types/error.go b/types/error.go index c94bd001..74c3bae5 100644 --- a/types/error.go +++ b/types/error.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "one-api/common" "strings" ) @@ -77,6 +78,7 @@ const ( type NewAPIError struct { Err error RelayError any + skipRetry bool errorType ErrorType errorCode ErrorCode StatusCode int @@ -107,19 +109,30 @@ func (e *NewAPIError) Error() string { return e.Err.Error() } +func (e *NewAPIError) MaskSensitiveError() string { + if e == nil { + return "" + } + if e.Err == nil { + return string(e.errorCode) + } + return common.MaskSensitiveInfo(e.Err.Error()) +} + func (e *NewAPIError) SetMessage(message string) { e.Err = errors.New(message) } func (e *NewAPIError) ToOpenAIError() OpenAIError { + var result OpenAIError switch e.errorType { case ErrorTypeOpenAIError: if openAIError, ok := e.RelayError.(OpenAIError); ok { - return openAIError + result = openAIError } case ErrorTypeClaudeError: if claudeError, ok := e.RelayError.(ClaudeError); ok { - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: claudeError.Type, Param: "", @@ -127,59 +140,70 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { } } } - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: string(e.errorType), Param: "", Code: e.errorCode, } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func (e *NewAPIError) ToClaudeError() ClaudeError { + var result ClaudeError switch e.errorType { case ErrorTypeOpenAIError: openAIError := e.RelayError.(OpenAIError) - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: fmt.Sprintf("%v", openAIError.Code), } case ErrorTypeClaudeError: - return e.RelayError.(ClaudeError) + result = e.RelayError.(ClaudeError) default: - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: string(e.errorType), } } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } -func NewError(err error, errorCode ErrorCode) *NewAPIError { - return &NewAPIError{ +type NewAPIErrorOptions func(*NewAPIError) + +func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: nil, errorType: ErrorTypeNewAPIError, StatusCode: http.StatusInternalServerError, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + return e } -func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError { +func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Message: err.Error(), Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError { +func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError { - return &NewAPIError{ +func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: OpenAIError{ Message: err.Error(), @@ -189,9 +213,14 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New StatusCode: statusCode, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + + return e } -func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { +func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { code, ok := openAIError.Code.(string) if !ok { code = fmt.Sprintf("%v", openAIError.Code) @@ -199,26 +228,34 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { if openAIError.Type == "" { openAIError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: openAIError, errorType: ErrorTypeOpenAIError, StatusCode: statusCode, Err: errors.New(openAIError.Message), errorCode: ErrorCode(code), } + for _, op := range ops { + op(e) + } + return e } -func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { +func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { if claudeError.Type == "" { claudeError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: claudeError, errorType: ErrorTypeClaudeError, StatusCode: statusCode, Err: errors.New(claudeError.Message), errorCode: ErrorCode(claudeError.Type), } + for _, op := range ops { + op(e) + } + return e } func IsChannelError(err *NewAPIError) bool { @@ -228,10 +265,16 @@ func IsChannelError(err *NewAPIError) bool { return strings.HasPrefix(string(err.errorCode), "channel:") } -func IsLocalError(err *NewAPIError) bool { +func IsSkipRetryError(err *NewAPIError) bool { if err == nil { return false } - return err.errorType == ErrorTypeNewAPIError + return err.skipRetry +} + +func ErrOptionWithSkipRetry() NewAPIErrorOptions { + return func(e *NewAPIError) { + e.skipRetry = true + } } diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index a62ec286..a3f09166 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -47,6 +47,7 @@ import { Highlight, } from '@douyinfe/semi-ui'; import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers'; +import ModelSelectModal from './ModelSelectModal'; import { IconSave, IconClose, @@ -141,6 +142,8 @@ const EditChannelModal = (props) => { const [customModel, setCustomModel] = useState(''); const [modalImageUrl, setModalImageUrl] = useState(''); const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const [modelModalVisible, setModelModalVisible] = useState(false); + const [fetchedModels, setFetchedModels] = useState([]); const formApiRef = useRef(null); const [vertexKeys, setVertexKeys] = useState([]); const [vertexFileList, setVertexFileList] = useState([]); @@ -378,7 +381,7 @@ const EditChannelModal = (props) => { // return; // } setLoading(true); - const models = inputs['models'] || []; + const models = []; let err = false; if (isEdit) { @@ -419,8 +422,9 @@ const EditChannelModal = (props) => { } if (!err) { - handleInputChange(name, Array.from(new Set(models))); - showSuccess(t('获取模型列表成功')); + const uniqueModels = Array.from(new Set(models)); + setFetchedModels(uniqueModels); + setModelModalVisible(true); } else { showError(t('获取模型列表失败')); } @@ -1650,6 +1654,17 @@ const EditChannelModal = (props) => { onVisibleChange={(visible) => setIsModalOpenurl(visible)} /> + { + handleInputChange('models', selectedModels); + showSuccess(t('模型列表已更新')); + setModelModalVisible(false); + }} + onCancel={() => setModelModalVisible(false)} + /> ); }; diff --git a/web/src/components/table/channels/modals/ModelSelectModal.jsx b/web/src/components/table/channels/modals/ModelSelectModal.jsx new file mode 100644 index 00000000..253d7254 --- /dev/null +++ b/web/src/components/table/channels/modals/ModelSelectModal.jsx @@ -0,0 +1,272 @@ +import React, { useState, useEffect } from 'react'; +import { useIsMobile } from '../../../../hooks/common/useIsMobile.js'; +import { Modal, Checkbox, Spin, Input, Typography, Empty, Tabs, Collapse } from '@douyinfe/semi-ui'; +import { + IllustrationNoResult, + IllustrationNoResultDark +} from '@douyinfe/semi-illustrations'; +import { IconSearch } from '@douyinfe/semi-icons'; +import { useTranslation } from 'react-i18next'; +import { getModelCategories } from '../../../../helpers/render'; + +const ModelSelectModal = ({ visible, models = [], selected = [], onConfirm, onCancel }) => { + const { t } = useTranslation(); + const [checkedList, setCheckedList] = useState(selected); + const [keyword, setKeyword] = useState(''); + const [activeTab, setActiveTab] = useState('new'); + + const isMobile = useIsMobile(); + + const filteredModels = models.filter((m) => m.toLowerCase().includes(keyword.toLowerCase())); + + // 分类模型:新获取的模型和已有模型 + const newModels = filteredModels.filter(model => !selected.includes(model)); + const existingModels = filteredModels.filter(model => selected.includes(model)); + + // 同步外部选中值 + useEffect(() => { + if (visible) { + setCheckedList(selected); + } + }, [visible, selected]); + + // 当模型列表变化时,设置默认tab + useEffect(() => { + if (visible) { + // 默认显示新获取模型tab,如果没有新模型则显示已有模型 + const hasNewModels = newModels.length > 0; + setActiveTab(hasNewModels ? 'new' : 'existing'); + } + }, [visible, newModels.length, selected]); + + const handleOk = () => { + onConfirm && onConfirm(checkedList); + }; + + // 按厂商分类模型 + const categorizeModels = (models) => { + const categories = getModelCategories(t); + const categorizedModels = {}; + const uncategorizedModels = []; + + models.forEach(model => { + let foundCategory = false; + for (const [key, category] of Object.entries(categories)) { + if (key !== 'all' && category.filter({ model_name: model })) { + if (!categorizedModels[key]) { + categorizedModels[key] = { + label: category.label, + icon: category.icon, + models: [] + }; + } + categorizedModels[key].models.push(model); + foundCategory = true; + break; + } + } + if (!foundCategory) { + uncategorizedModels.push(model); + } + }); + + // 如果有未分类模型,添加到"其他"分类 + if (uncategorizedModels.length > 0) { + categorizedModels['other'] = { + label: t('其他'), + icon: null, + models: uncategorizedModels + }; + } + + return categorizedModels; + }; + + const newModelsByCategory = categorizeModels(newModels); + const existingModelsByCategory = categorizeModels(existingModels); + + // Tab列表配置 + const tabList = [ + ...(newModels.length > 0 ? [{ + tab: `${t('新获取的模型')} (${newModels.length})`, + itemKey: 'new' + }] : []), + ...(existingModels.length > 0 ? [{ + tab: `${t('已有的模型')} (${existingModels.length})`, + itemKey: 'existing' + }] : []) + ]; + + // 处理分类全选/取消全选 + const handleCategorySelectAll = (categoryModels, isChecked) => { + let newCheckedList = [...checkedList]; + + if (isChecked) { + // 全选:添加该分类下所有未选中的模型 + categoryModels.forEach(model => { + if (!newCheckedList.includes(model)) { + newCheckedList.push(model); + } + }); + } else { + // 取消全选:移除该分类下所有已选中的模型 + newCheckedList = newCheckedList.filter(model => !categoryModels.includes(model)); + } + + setCheckedList(newCheckedList); + }; + + // 检查分类是否全选 + const isCategoryAllSelected = (categoryModels) => { + return categoryModels.length > 0 && categoryModels.every(model => checkedList.includes(model)); + }; + + // 检查分类是否部分选中 + const isCategoryIndeterminate = (categoryModels) => { + const selectedCount = categoryModels.filter(model => checkedList.includes(model)).length; + return selectedCount > 0 && selectedCount < categoryModels.length; + }; + + const renderModelsByCategory = (modelsByCategory, categoryKeyPrefix) => { + const categoryEntries = Object.entries(modelsByCategory); + if (categoryEntries.length === 0) return null; + + // 生成所有面板的key,确保都展开 + const allActiveKeys = categoryEntries.map((_, index) => `${categoryKeyPrefix}_${index}`); + + return ( + + {categoryEntries.map(([key, categoryData], index) => ( + { + e.stopPropagation(); // 防止触发面板折叠 + handleCategorySelectAll(categoryData.models, e.target.checked); + }} + onClick={(e) => e.stopPropagation()} // 防止点击checkbox时折叠面板 + /> + } + > +
+ {categoryData.icon} + + {t('已选择 {{selected}} / {{total}}', { + selected: categoryData.models.filter(model => checkedList.includes(model)).length, + total: categoryData.models.length + })} + +
+
+ {categoryData.models.map((model) => ( + + {model} + + ))} +
+
+ ))} +
+ ); + }; + + return ( + + + {t('选择模型')} + +
+ setActiveTab(key)} + /> +
+ + } + visible={visible} + onOk={handleOk} + onCancel={onCancel} + okText={t('确定')} + cancelText={t('取消')} + size={isMobile ? 'full-width' : 'large'} + closeOnEsc + maskClosable + centered + > + } + placeholder={t('搜索模型')} + value={keyword} + onChange={(v) => setKeyword(v)} + showClear + /> + + +
+ {filteredModels.length === 0 ? ( + } + darkModeImage={} + description={t('暂无匹配模型')} + style={{ padding: 30 }} + /> + ) : ( + setCheckedList(vals)}> + {activeTab === 'new' && newModels.length > 0 && ( +
+ {renderModelsByCategory(newModelsByCategory, 'new')} +
+ )} + {activeTab === 'existing' && existingModels.length > 0 && ( +
+ {renderModelsByCategory(existingModelsByCategory, 'existing')} +
+ )} +
+ )} +
+
+ + +
+ {(() => { + const currentModels = activeTab === 'new' ? newModels : existingModels; + const currentSelected = currentModels.filter(model => checkedList.includes(model)).length; + const isAllSelected = currentModels.length > 0 && currentSelected === currentModels.length; + const isIndeterminate = currentSelected > 0 && currentSelected < currentModels.length; + + return ( + <> + + {t('已选择 {{selected}} / {{total}}', { + selected: currentSelected, + total: currentModels.length + })} + + { + handleCategorySelectAll(currentModels, e.target.checked); + }} + /> + + ); + })()} +
+
+
+ ); +}; + +export default ModelSelectModal; \ No newline at end of file diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index bd0a8131..1178d5f9 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -883,12 +883,22 @@ export function renderQuotaWithAmount(amount) { } export function renderQuota(quota, digits = 2) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); let displayInCurrency = localStorage.getItem('display_in_currency'); quotaPerUnit = parseFloat(quotaPerUnit); displayInCurrency = displayInCurrency === 'true'; if (displayInCurrency) { - return '$' + (quota / quotaPerUnit).toFixed(digits); + const result = quota / quotaPerUnit; + const fixedResult = result.toFixed(digits); + + // 如果 toFixed 后结果为 0 但原始值不为 0,显示最小值 + if (parseFloat(fixedResult) === 0 && quota > 0 && result > 0) { + const minValue = Math.pow(10, -digits); + return '$' + minValue.toFixed(digits); + } + + return '$' + fixedResult; } return renderNumber(quota); } diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 5d4cbef4..b965c192 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1799,5 +1799,10 @@ "显示第": "Showing", "条 - 第": "to", "条,共": "of", - "条": "items" + "条": "items", + "选择模型": "Select model", + "已选择 {{selected}} / {{total}}": "Selected {{selected}} / {{total}}", + "新获取的模型": "New models", + "已有的模型": "Existing models", + "搜索模型": "Search models" } \ No newline at end of file diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index efb355df..bbdd623e 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -147,6 +147,7 @@ export default function RequestRateLimit(props) { label={t('用户每周期最多请求次数')} step={1} min={0} + max={100000000} suffix={t('次')} extraText={t('包括失败请求的次数,0代表不限制')} field={'ModelRequestRateLimitCount'} @@ -163,6 +164,7 @@ export default function RequestRateLimit(props) { label={t('用户每周期最多请求完成次数')} step={1} min={1} + max={100000000} suffix={t('次')} extraText={t('只包括请求成功的次数')} field={'ModelRequestRateLimitSuccessCount'} @@ -199,6 +201,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('[最多请求次数]和[最多请求完成次数]的最大值为2147483647。')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}