diff --git a/common/gin.go b/common/gin.go index d428184a..f876a92b 100644 --- a/common/gin.go +++ b/common/gin.go @@ -32,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { } contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = UnmarshalJson(requestBody, &v) + err = Unmarshal(requestBody, &v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this diff --git a/common/json.go b/common/json.go index 512ad0c3..69aa952e 100644 --- a/common/json.go +++ b/common/json.go @@ -5,7 +5,7 @@ import ( "encoding/json" ) -func UnmarshalJson(data []byte, v any) error { +func Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) } @@ -17,6 +17,6 @@ func DecodeJson(reader *bytes.Reader, v any) error { return json.NewDecoder(reader).Decode(v) } -func EncodeJson(v any) ([]byte, error) { +func Marshal(v any) ([]byte, error) { return json.Marshal(v) } diff --git a/common/str.go b/common/str.go index 76dd801a..88b58c72 100644 --- a/common/str.go +++ b/common/str.go @@ -32,16 +32,30 @@ func MapToJsonStr(m map[string]interface{}) string { return string(bytes) } -func StrToMap(str string) map[string]interface{} { +func StrToMap(str string) (map[string]interface{}, error) { m := make(map[string]interface{}) - err := json.Unmarshal([]byte(str), &m) + err := Unmarshal([]byte(str), &m) if err != nil { - return nil + return nil, err } - return m + return m, nil } -func IsJsonStr(str string) bool { +func StrToJsonArray(str string) ([]interface{}, error) { + var js []interface{} + err := json.Unmarshal([]byte(str), &js) + if err != nil { + return nil, err + } + return js, nil +} + +func IsJsonArray(str string) bool { + var js []interface{} + return json.Unmarshal([]byte(str), &js) == nil +} + +func IsJsonObject(str string) bool { var js map[string]interface{} return json.Unmarshal([]byte(str), &js) == nil } diff --git a/constant/context_key.go 
b/constant/context_key.go index 71e02f01..4eaf3d00 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -17,11 +17,20 @@ const ( ContextKeyTokenModelLimit ContextKey = "token_model_limit" /* channel related keys */ - ContextKeyBaseUrl ContextKey = "base_url" - ContextKeyChannelType ContextKey = "channel_type" - ContextKeyChannelId ContextKey = "channel_id" - ContextKeyChannelSetting ContextKey = "channel_setting" - ContextKeyParamOverride ContextKey = "param_override" + ContextKeyChannelId ContextKey = "channel_id" + ContextKeyChannelName ContextKey = "channel_name" + ContextKeyChannelCreateTime ContextKey = "channel_create_time" + ContextKeyChannelBaseUrl ContextKey = "base_url" + ContextKeyChannelType ContextKey = "channel_type" + ContextKeyChannelSetting ContextKey = "channel_setting" + ContextKeyChannelParamOverride ContextKey = "param_override" + ContextKeyChannelOrganization ContextKey = "channel_organization" + ContextKeyChannelAutoBan ContextKey = "auto_ban" + ContextKeyChannelModelMapping ContextKey = "model_mapping" + ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping" + ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key" + ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index" + ContextKeyChannelKey ContextKey = "channel_key" /* user related keys */ ContextKeyUserId ContextKey = "id" diff --git a/constant/multi_key_mode.go b/constant/multi_key_mode.go new file mode 100644 index 00000000..cd0cdbff --- /dev/null +++ b/constant/multi_key_mode.go @@ -0,0 +1,8 @@ +package constant + +type MultiKeyMode string + +const ( + MultiKeyModeRandom MultiKeyMode = "random" // 随机 + MultiKeyModePolling MultiKeyMode = "polling" // 轮询 +) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 3c92c78b..2c2c25b9 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -12,6 +12,7 @@ import ( "one-api/model" "one-api/service" "one-api/setting" + "one-api/types" 
"strconv" "time" @@ -415,7 +416,7 @@ func UpdateChannelBalance(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id, true) + channel, err := model.CacheGetChannel(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -423,6 +424,13 @@ func UpdateChannelBalance(c *gin.Context) { }) return } + if channel.ChannelInfo.IsMultiKey { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "多密钥渠道不支持余额查询", + }) + return + } balance, err := updateChannelBalance(channel) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -436,7 +444,6 @@ func UpdateChannelBalance(c *gin.Context) { "message": "", "balance": balance, }) - return } func updateAllChannelsBalance() error { @@ -448,6 +455,9 @@ func updateAllChannelsBalance() error { if channel.Status != common.ChannelStatusEnabled { continue } + if channel.ChannelInfo.IsMultiKey { + continue // skip multi-key channels + } // TODO: support Azure //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { // continue @@ -458,7 +468,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - service.DisableChannel(channel.Id, channel.Name, "余额不足") + service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足") } } time.Sleep(common.RequestInterval) diff --git a/controller/channel-test.go b/controller/channel-test.go index efc8cd96..203c91a2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -19,6 +19,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strconv" "strings" "sync" @@ -29,22 +30,43 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { +type testResult struct { + context *gin.Context + localErr error + 
newAPIError *types.NewAPIError +} + +func testChannel(channel *model.Channel, testModel string) testResult { tik := time.Now() if channel.Type == constant.ChannelTypeMidjourney { - return errors.New("midjourney channel test is not supported"), nil + return testResult{ + localErr: errors.New("midjourney channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeMidjourneyPlus { - return errors.New("midjourney plus channel test is not supported"), nil + return testResult{ + localErr: errors.New("midjourney plus channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeSunoAPI { - return errors.New("suno channel test is not supported"), nil + return testResult{ + localErr: errors.New("suno channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeKling { - return errors.New("kling channel test is not supported"), nil + return testResult{ + localErr: errors.New("kling channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeJimeng { - return errors.New("jimeng channel test is not supported"), nil + return testResult{ + localErr: errors.New("jimeng channel test is not supported"), + newAPIError: nil, + } } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -81,31 +103,49 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr cache, err := model.GetUserCache(1) if err != nil { - return err, nil + return testResult{ + localErr: err, + newAPIError: nil, + } } cache.WriteContext(c) - c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + //c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) group, _ := model.GetUserGroup(1, false) c.Set("group", group) - middleware.SetupContextForSelectedChannel(c, channel, testModel) + newAPIError 
:= middleware.SetupContextForSelectedChannel(c, channel, testModel) + if newAPIError != nil { + return testResult{ + context: c, + localErr: newAPIError, + newAPIError: newAPIError, + } + } info := relaycommon.GenRelayInfo(c) err = helper.ModelMappedHelper(c, info, nil) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), + } } testModel = info.UpstreamModelName apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + return testResult{ + context: c, + localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), + newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType), + } } request := buildTestRequest(testModel) @@ -116,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeModelPriceError), + } } adaptor.Init(info) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), + } } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed), + } } requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: 
types.NewError(err, types.ErrorCodeDoRequestFailed), + } } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { err := service.RelayErrorHandler(httpResp, true) - return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeBadResponse), + } } } usageA, respErr := adaptor.DoResponse(c, httpResp, info) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), respErr + return testResult{ + context: c, + localErr: respErr, + newAPIError: respErr, + } } if usageA == nil { - return errors.New("usage is nil"), nil + return testResult{ + context: c, + localErr: errors.New("usage is nil"), + newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody), + } } usage := usageA.(*dto.Usage) result := w.Result() respBody, err := io.ReadAll(result.Body) if err != nil { - return err, nil + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed), + } } info.PromptTokens = usage.PromptTokens @@ -187,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr Other: other, }) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) - return nil, nil + return testResult{ + context: c, + localErr: nil, + newAPIError: nil, + } } func buildTestRequest(model string) *dto.GeneralOpenAIRequest { @@ -236,7 +312,7 @@ func TestChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(channelId, true) + channel, err := model.CacheGetChannel(channelId) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -244,17 +320,30 @@ func TestChannel(c *gin.Context) { }) return } + //defer func() { + // if channel.ChannelInfo.IsMultiKey { + // go func() { _ = channel.SaveChannelInfo() }() + // } + //}() 
testModel := c.Query("model") tik := time.Now() - err, _ = testChannel(channel, testModel) + result := testChannel(channel, testModel) + if result.localErr != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": result.localErr.Error(), + "time": 0.0, + }) + return + } tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) consumedTime := float64(milliseconds) / 1000.0 - if err != nil { + if result.newAPIError != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": err.Error(), + "message": result.newAPIError.Error(), "time": consumedTime, }) return @@ -279,9 +368,9 @@ func testAllChannels(notify bool) error { } testAllChannelsRunning = true testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - return err + channels, getChannelErr := model.GetAllChannels(0, 0, true, false) + if getChannelErr != nil { + return getChannelErr } var disableThreshold = int64(common.ChannelDisableThreshold * 1000) if disableThreshold == 0 { @@ -298,32 +387,34 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiWithStatusErr := testChannel(channel, "") + result := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() shouldBanChannel := false - + newAPIError := result.newAPIError // request error disables the channel - if openaiWithStatusErr != nil { - oaiErr := openaiWithStatusErr.Error - err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) - shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) + if newAPIError != nil { + shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError) } - if milliseconds > disableThreshold { - err = 
errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - shouldBanChannel = true + // 当错误检查通过,才检查响应时间 + if common.AutomaticDisableChannelEnabled && !shouldBanChannel { + if milliseconds > disableThreshold { + err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded) + shouldBanChannel = true + } } // disable channel if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { - service.DisableChannel(channel.Id, channel.Name, err.Error()) + go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) } // enable channel - if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { - service.EnableChannel(channel.Id, channel.Name) + if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) { + service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/channel.go b/controller/channel.go index 85b14b43..ee6ddeba 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -380,9 +380,47 @@ func GetChannel(c *gin.Context) { return } +type AddChannelRequest struct { + Mode string `json:"mode"` + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + Channel *model.Channel `json:"channel"` +} + +func getVertexArrayKeys(keys string) ([]string, error) { + if keys == "" { + return nil, nil + } + var keyArray []interface{} + err := common.Unmarshal([]byte(keys), &keyArray) + if err != nil { + return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err) + } + 
cleanKeys := make([]string, 0, len(keyArray)) + for _, key := range keyArray { + var keyStr string + switch v := key.(type) { + case string: + keyStr = strings.TrimSpace(v) + default: + bytes, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err) + } + keyStr = string(bytes) + } + if keyStr != "" { + cleanKeys = append(cleanKeys, keyStr) + } + } + if len(cleanKeys) == 0 { + return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空") + } + return cleanKeys, nil +} + func AddChannel(c *gin.Context) { - channel := model.Channel{} - err := c.ShouldBindJSON(&channel) + addChannelRequest := AddChannelRequest{} + err := c.ShouldBindJSON(&addChannelRequest) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -390,7 +428,8 @@ func AddChannel(c *gin.Context) { }) return } - err = channel.ValidateSettings() + + err = addChannelRequest.Channel.ValidateSettings() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -398,49 +437,113 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() - keys := strings.Split(channel.Key, "\n") - if channel.Type == constant.ChannelTypeVertexAi { - if channel.Other == "" { + + if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel cannot be empty", + }) + return + } + + // Validate the length of the model name + for _, m := range addChannelRequest.Channel.GetModels() { + if len(m) > 255 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("模型名称过长: %s", m), + }) + return + } + } + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + if addChannelRequest.Channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区不能为空", }) return } else { - if common.IsJsonStr(channel.Other) { - // must have default - regionMap := common.StrToMap(channel.Other) - if regionMap["default"] == nil { - 
c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } + regionMap, err := common.StrToMap(addChannelRequest.Channel.Other) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", + }) + return + } + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return } } - keys = []string{channel.Key} } + + addChannelRequest.Channel.CreatedTime = common.GetTimestamp() + keys := make([]string, 0) + switch addChannelRequest.Mode { + case "multi_to_single": + addChannelRequest.Channel.ChannelInfo.IsMultiKey = true + addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array) + addChannelRequest.Channel.Key = strings.Join(array, "\n") + } else { + cleanKeys := make([]string, 0) + for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") { + if key == "" { + continue + } + key = strings.TrimSpace(key) + cleanKeys = append(cleanKeys, key) + } + addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys) + addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") + } + keys = []string{addChannelRequest.Channel.Key} + case "batch": + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + // multi json + keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + keys = strings.Split(addChannelRequest.Channel.Key, "\n") + } + case "single": + keys = 
[]string{addChannelRequest.Channel.Key} + default: + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不支持的添加模式", + }) + return + } + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue } - localChannel := channel + localChannel := addChannelRequest.Channel localChannel.Key = key - // Validate the length of the model name - models := strings.Split(localChannel.Models, ",") - for _, model := range models { - if len(model) > 255 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("模型名称过长: %s", model), - }) - return - } - } - channels = append(channels, localChannel) + channels = append(channels, *localChannel) } err = model.BatchInsertChannels(channels) if err != nil { @@ -615,8 +718,13 @@ func DeleteChannelBatch(c *gin.Context) { return } +type PatchChannel struct { + model.Channel + MultiKeyMode *string `json:"multi_key_mode"` +} + func UpdateChannel(c *gin.Context) { - channel := model.Channel{} + channel := PatchChannel{} err := c.ShouldBindJSON(&channel) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -641,17 +749,34 @@ func UpdateChannel(c *gin.Context) { }) return } else { - if common.IsJsonStr(channel.Other) { - // must have default - regionMap := common.StrToMap(channel.Other) - if regionMap["default"] == nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } + regionMap, err := common.StrToMap(channel.Other) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", + }) + return } + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return + } + } + } + if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" { + originChannel, err := model.GetChannelById(channel.Id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": 
false, + "message": err.Error(), + }); return + } + if originChannel.ChannelInfo.IsMultiKey { + channel.ChannelInfo = originChannel.ChannelInfo + channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode) } } err = channel.Update() diff --git a/controller/playground.go b/controller/playground.go index 33471455..e071d12e 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -3,45 +3,44 @@ package controller import ( "errors" "fmt" - "net/http" "one-api/common" "one-api/constant" "one-api/dto" "one-api/middleware" "one-api/model" - "one-api/service" "one-api/setting" + "one-api/types" "time" "github.com/gin-gonic/gin" ) func Playground(c *gin.Context) { - var openaiErr *dto.OpenAIErrorWithStatusCode + var newAPIError *types.NewAPIError defer func() { - if openaiErr != nil { - c.JSON(openaiErr.StatusCode, gin.H{ - "error": openaiErr.Error, + if newAPIError != nil { + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), }) } }() useAccessToken := c.GetBool("use_access_token") if useAccessToken { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest) + newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied) return } playgroundRequest := &dto.PlayGroundRequest{} err := common.UnmarshalBodyReusable(c, playgroundRequest) if err != nil { - openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest) + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) return } if playgroundRequest.Model == "" { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest) + newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest) return } c.Set("original_model", playgroundRequest.Model) @@ -52,26 +51,32 @@ func Playground(c *gin.Context) { group = userGroup } else { if 
!setting.GroupInUserUsableGroups(group) && group != userGroup { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden) + newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied) return } c.Set("group", group) } - c.Set("token_name", "playground-"+group) - channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0) + + userId := c.GetInt("id") + //c.Set("token_name", "playground-"+group) + tempToken := &model.Token{ + UserId: userId, + Name: fmt.Sprintf("playground-%s", group), + Group: group, + } + _ = middleware.SetupContextForToken(c, tempToken) + _, err = getChannel(c, group, playgroundRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model) - openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) + newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) return } - middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) + //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) // Write user context to ensure acceptUnsetRatio is available - userId := c.GetInt("id") userCache, err := model.GetUserCache(userId) if err != nil { - openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError) + newAPIError = types.NewError(err, types.ErrorCodeQueryDataError) return } userCache.WriteContext(c) diff --git a/controller/relay.go b/controller/relay.go index e375120b..b224b42c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -17,14 +17,15 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) -func 
relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - var err *dto.OpenAIErrorWithStatusCode +func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { + var err *types.NewAPIError switch relayMode { case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: err = relay.ImageHelper(c) @@ -55,14 +56,14 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode userGroup := c.GetString("group") channelId := c.GetInt("channel_id") other := make(map[string]interface{}) - other["error_type"] = err.Error.Type - other["error_code"] = err.Error.Code + other["error_type"] = err.ErrorType + other["error_code"] = err.GetErrorCode() other["status_code"] = err.StatusCode other["channel_id"] = channelId other["channel_name"] = c.GetString("channel_name") other["channel_type"] = c.GetInt("channel_type") - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other) + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other) } return err @@ -73,25 +74,25 @@ func Relay(c *gin.Context) { requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") - var openaiErr *dto.OpenAIErrorWithStatusCode + var newAPIError *types.NewAPIError for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + newAPIError = err break } - openaiErr = relayRequest(c, relayMode, channel) + newAPIError = relayRequest(c, relayMode, channel) - if openaiErr == nil { + if newAPIError == nil { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + go processChannelError(c, 
*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break } } @@ -101,14 +102,14 @@ func Relay(c *gin.Context) { common.LogInfo(c, retryLogStr) } - if openaiErr != nil { - if openaiErr.StatusCode == http.StatusTooManyRequests { - common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message)) - openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" - } - openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) - c.JSON(openaiErr.StatusCode, gin.H{ - "error": openaiErr.Error, + if newAPIError != nil { + //if newAPIError.StatusCode == http.StatusTooManyRequests { + // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) + // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") + //} + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), }) } } @@ -127,8 +128,7 @@ func WssRelay(c *gin.Context) { defer ws.Close() if err != nil { - openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) - helper.WssError(c, ws, openaiErr.Error) + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError()) return } @@ -137,25 +137,25 @@ func WssRelay(c *gin.Context) { group := c.GetString("group") //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 originalModel := c.GetString("original_model") - var openaiErr *dto.OpenAIErrorWithStatusCode + var newAPIError *types.NewAPIError for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - openaiErr = service.OpenAIErrorWrapperLocal(err, 
"get_channel_failed", http.StatusInternalServerError) + newAPIError = err break } - openaiErr = wssRequest(c, ws, relayMode, channel) + newAPIError = wssRequest(c, ws, relayMode, channel) - if openaiErr == nil { + if newAPIError == nil { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break } } @@ -165,12 +165,12 @@ func WssRelay(c *gin.Context) { common.LogInfo(c, retryLogStr) } - if openaiErr != nil { - if openaiErr.StatusCode == http.StatusTooManyRequests { - openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" - } - openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) - helper.WssError(c, ws, openaiErr.Error) + if newAPIError != nil { + //if newAPIError.StatusCode == http.StatusTooManyRequests { + // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") + //} + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + helper.WssError(c, ws, newAPIError.ToOpenAIError()) } } @@ -179,27 +179,25 @@ func RelayClaude(c *gin.Context) { requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") - var claudeErr *dto.ClaudeErrorWithStatusCode + var newAPIError *types.NewAPIError for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + newAPIError = err break } - claudeErr = claudeRequest(c, channel) + newAPIError = claudeRequest(c, channel) - if claudeErr == nil { + 
if newAPIError == nil { return // 成功处理请求,直接返回 } - openaiErr := service.ClaudeErrorToOpenAIError(claudeErr) + go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) - - if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break } } @@ -209,30 +207,30 @@ func RelayClaude(c *gin.Context) { common.LogInfo(c, retryLogStr) } - if claudeErr != nil { - claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId) - c.JSON(claudeErr.StatusCode, gin.H{ + if newAPIError != nil { + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + c.JSON(newAPIError.StatusCode, gin.H{ "type": "error", - "error": claudeErr.Error, + "error": newAPIError.ToClaudeError(), }) } } -func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { +func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError { addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return relayHandler(c, relayMode) } -func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { +func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError { addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return relay.WssHelper(c, ws) } -func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode { +func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError { 
addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) @@ -245,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) { c.Set("use_channel", useChannel) } -func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { +func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) { if retryCount == 0 { autoBan := c.GetBool("auto_ban") autoBanInt := 1 @@ -259,19 +257,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) + channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) + if group == "auto" { + return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + } + return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + } + newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) + if newAPIError != nil { + return nil, newAPIError } - middleware.SetupContextForSelectedChannel(c, channel, originalModel) return channel, nil } -func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { +func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool { if openaiErr == nil { return false } - if openaiErr.LocalError { + if types.IsChannelError(openaiErr) { + return true + } + if types.IsLocalError(openaiErr) { return false } if retryTimes <= 0 { @@ -310,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry 
return true } -func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { +func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) - if service.ShouldDisableChannel(channelType, err) && autoBan { - service.DisableChannel(channelId, channelName, err.Error.Message) + common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + service.DisableChannel(channelError, err.Error()) } } @@ -388,9 +395,10 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + channel, newAPIError := getChannel(c, group, originalModel, i) + if newAPIError != nil { + common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } channelId = channel.Id @@ -398,9 +406,9 @@ func RelayTask(c *gin.Context) { useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) - middleware.SetupContextForSelectedChannel(c, channel, originalModel) + 
//middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) + requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) taskErr = taskRelayHandler(c, relayMode) } diff --git a/dto/claude.go b/dto/claude.go index 98e09c78..b5d43f23 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -3,6 +3,7 @@ package dto import ( "encoding/json" "one-api/common" + "one-api/types" ) type ClaudeMetadata struct { @@ -228,7 +229,7 @@ type ClaudeResponse struct { Completion string `json:"completion,omitempty"` StopReason string `json:"stop_reason,omitempty"` Model string `json:"model,omitempty"` - Error *ClaudeError `json:"error,omitempty"` + Error *types.ClaudeError `json:"error,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"` Index *int `json:"index,omitempty"` ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` diff --git a/dto/error.go b/dto/error.go index b347f6a1..d7f6824d 100644 --- a/dto/error.go +++ b/dto/error.go @@ -1,5 +1,7 @@ package dto +import "one-api/types" + type OpenAIError struct { Message string `json:"message"` Type string `json:"type"` @@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct { } type GeneralErrorResponse struct { - Error OpenAIError `json:"error"` - Message string `json:"message"` - Msg string `json:"msg"` - Err string `json:"err"` - ErrorMsg string `json:"error_msg"` + Error types.OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` Header struct { Message string `json:"message"` } `json:"header"` diff --git a/dto/openai_request.go b/dto/openai_request.go index aa4f1962..6cb554c7 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -66,8 +66,8 @@ type GeneralOpenAIRequest struct { func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) - data, _ := common.EncodeJson(r) - _ = 
common.UnmarshalJson(data, &result) + data, _ := common.Marshal(r) + _ = common.Unmarshal(data, &result) return result } diff --git a/dto/openai_response.go b/dto/openai_response.go index d95acd9e..64601427 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -1,6 +1,9 @@ package dto -import "encoding/json" +import ( + "encoding/json" + "one-api/types" +) type SimpleResponse struct { Usage `json:"usage"` @@ -28,7 +31,7 @@ type OpenAITextResponse struct { Object string `json:"object"` Created any `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` - Error *OpenAIError `json:"error,omitempty"` + Error *types.OpenAIError `json:"error,omitempty"` Usage `json:"usage"` } @@ -201,7 +204,7 @@ type OpenAIResponsesResponse struct { Object string `json:"object"` CreatedAt int `json:"created_at"` Status string `json:"status"` - Error *OpenAIError `json:"error,omitempty"` + Error *types.OpenAIError `json:"error,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` Instructions string `json:"instructions"` MaxOutputTokens int `json:"max_output_tokens"` diff --git a/dto/realtime.go b/dto/realtime.go index 86ae352d..32a69056 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -1,5 +1,7 @@ package dto +import "one-api/types" + const ( RealtimeEventTypeError = "error" RealtimeEventTypeSessionUpdate = "session.update" @@ -23,12 +25,12 @@ type RealtimeEvent struct { EventId string `json:"event_id"` Type string `json:"type"` //PreviousItemId string `json:"previous_item_id"` - Session *RealtimeSession `json:"session,omitempty"` - Item *RealtimeItem `json:"item,omitempty"` - Error *OpenAIError `json:"error,omitempty"` - Response *RealtimeResponse `json:"response,omitempty"` - Delta string `json:"delta,omitempty"` - Audio string `json:"audio,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Error *types.OpenAIError `json:"error,omitempty"` + Response 
*RealtimeResponse `json:"response,omitempty"` + Delta string `json:"delta,omitempty"` + Audio string `json:"audio,omitempty"` } type RealtimeResponse struct { diff --git a/middleware/auth.go b/middleware/auth.go index ecf4844b..47d033a9 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "one-api/common" "one-api/model" @@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) { userCache.WriteContext(c) - c.Set("id", token.UserId) - c.Set("token_id", token.Id) - c.Set("token_key", token.Key) - c.Set("token_name", token.Name) - c.Set("token_unlimited_quota", token.UnlimitedQuota) - if !token.UnlimitedQuota { - c.Set("token_quota", token.RemainQuota) - } - if token.ModelLimitsEnabled { - c.Set("token_model_limit_enabled", true) - c.Set("token_model_limit", token.GetModelLimitsMap()) - } else { - c.Set("token_model_limit_enabled", false) - } - c.Set("allow_ips", token.GetIpLimitsMap()) - c.Set("token_group", token.Group) - if len(parts) > 1 { - if model.IsAdmin(token.UserId) { - c.Set("specific_channel_id", parts[1]) - } else { - abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") - return - } + err = SetupContextForToken(c, token, parts...) 
+ if err != nil { + return } c.Next() } } + +func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error { + if token == nil { + return fmt.Errorf("token is nil") + } + c.Set("id", token.UserId) + c.Set("token_id", token.Id) + c.Set("token_key", token.Key) + c.Set("token_name", token.Name) + c.Set("token_unlimited_quota", token.UnlimitedQuota) + if !token.UnlimitedQuota { + c.Set("token_quota", token.RemainQuota) + } + if token.ModelLimitsEnabled { + c.Set("token_model_limit_enabled", true) + c.Set("token_model_limit", token.GetModelLimitsMap()) + } else { + c.Set("token_model_limit_enabled", false) + } + c.Set("allow_ips", token.GetIpLimitsMap()) + c.Set("token_group", token.Group) + if len(parts) > 1 { + if model.IsAdmin(token.UserId) { + c.Set("specific_channel_id", parts[1]) + } else { + abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + return fmt.Errorf("普通用户不支持指定渠道") + } + } + return nil +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 642b5253..a6889e39 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,6 +12,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strconv" "strings" "time" @@ -21,6 +22,7 @@ import ( type ModelRequest struct { Model string `json:"model"` + Group string `json:"group,omitempty"` } func Distribute() func(c *gin.Context) { @@ -237,28 +239,47 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("relay_mode", relayMode) } + if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { + // playground chat completions + err = common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return nil, false, errors.New("无效的请求, " + err.Error()) + } + common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group) + } return &modelRequest, shouldSelectChannel, nil } -func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName 
string) { +func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) } - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) + common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) + common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) - c.Set("channel_create_time", channel.CreatedTime) + common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) - c.Set("param_override", channel.GetParamOverride()) - if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { - c.Set("channel_organization", *channel.OpenAIOrganization) + common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) + if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { + common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) } - c.Set("auto_ban", channel.GetAutoBan()) - c.Set("model_mapping", channel.GetModelMapping()) - c.Set("status_code_mapping", channel.GetStatusCodeMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan()) + common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping()) + common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping()) + + key, index, newAPIError := channel.GetNextEnabledKey() + if newAPIError != nil { + return newAPIError + } + if channel.ChannelInfo.IsMultiKey { + 
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) + common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index) + } + // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) + common.SetContextKey(c, constant.ContextKeyChannelKey, key) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) + // TODO: api_version统一 switch channel.Type { case constant.ChannelTypeAzure: @@ -278,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode case constant.ChannelTypeCoze: c.Set("bot_id", channel.Other) } + return nil } // extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名 diff --git a/model/channel.go b/model/channel.go index a5f307ef..b792c87e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,9 +1,15 @@ package model import ( + "database/sql/driver" "encoding/json" + "errors" + "fmt" + "math/rand" "one-api/common" + "one-api/constant" "one-api/dto" + "one-api/types" "strings" "sync" @@ -36,8 +42,126 @@ type Channel struct { AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` Tag *string `json:"tag" gorm:"index"` - Setting *string `json:"setting" gorm:"type:text"` + Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` + // add after v0.8.5 + ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` +} + +type ChannelInfo struct { + IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status + MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` +} + +// Value implements driver.Valuer interface +func (c ChannelInfo) Value() (driver.Value, error) { + return common.Marshal(&c) +} + +// Scan implements sql.Scanner interface 
+func (c *ChannelInfo) Scan(value interface{}) error { + bytesValue, _ := value.([]byte) + return common.Unmarshal(bytesValue, c) +} + +func (channel *Channel) getKeys() []string { + if channel.Key == "" { + return []string{} + } + // use \n to split keys + keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n") + return keys +} + +func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { + // If not in multi-key mode, return the original key string directly. + if !channel.ChannelInfo.IsMultiKey { + return channel.Key, 0, nil + } + + // Obtain all keys (split by \n) + keys := channel.getKeys() + if len(keys) == 0 { + // No keys available, return error, should disable the channel + return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) + } + + statusList := channel.ChannelInfo.MultiKeyStatusList + // helper to get key status, default to enabled when missing + getStatus := func(idx int) int { + if statusList == nil { + return common.ChannelStatusEnabled + } + if status, ok := statusList[idx]; ok { + return status + } + return common.ChannelStatusEnabled + } + + // Collect indexes of enabled keys + enabledIdx := make([]int, 0, len(keys)) + for i := range keys { + if getStatus(i) == common.ChannelStatusEnabled { + enabledIdx = append(enabledIdx, i) + } + } + // If no specific status list or none enabled, fall back to first key + if len(enabledIdx) == 0 { + return keys[0], 0, nil + } + + switch channel.ChannelInfo.MultiKeyMode { + case constant.MultiKeyModeRandom: + // Randomly pick one enabled key + selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))] + return keys[selectedIdx], selectedIdx, nil + case constant.MultiKeyModePolling: + // Use channel-specific lock to ensure thread-safe polling + lock := getChannelPollingLock(channel.Id) + lock.Lock() + defer lock.Unlock() + + channelInfo, err := CacheGetChannelInfo(channel.Id) + if err != nil { + return "", 0, types.NewError(err, 
types.ErrorCodeGetChannelFailed) + } + //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) + defer func() { + if common.DebugEnabled { + println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex)) + } + if !common.MemoryCacheEnabled { + _ = channel.SaveChannelInfo() + } else { + // CacheUpdateChannel(channel) + } + }() + // Start from the saved polling index and look for the next enabled key + start := channelInfo.MultiKeyPollingIndex + if start < 0 || start >= len(keys) { + start = 0 + } + for i := 0; i < len(keys); i++ { + idx := (start + i) % len(keys) + if getStatus(idx) == common.ChannelStatusEnabled { + // update polling index for next call (point to the next position) + channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys) + return keys[idx], idx, nil + } + } + // Fallback – should not happen, but return first enabled key + return keys[enabledIdx[0]], enabledIdx[0], nil + default: + // Unknown mode, default to first enabled key (or original key string) + return keys[enabledIdx[0]], enabledIdx[0], nil + } +} + +func (channel *Channel) SaveChannelInfo() error { + return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error } func (channel *Channel) GetModels() []string { @@ -175,14 +299,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([] } func GetChannelById(id int, selectAll bool) (*Channel, error) { - channel := Channel{Id: id} + channel := &Channel{Id: id} var err error = nil if selectAll { - err = DB.First(&channel, "id = ?", id).Error + err = DB.First(channel, "id = ?", id).Error } else { - err = DB.Omit("key").First(&channel, "id = ?", id).Error + err = DB.Omit("key").First(channel, "id = ?", id).Error } - return &channel, err + if err != nil { + return nil, err + } + if channel == nil { + return nil, errors.New("channel not found") + } + return channel, nil } func BatchInsertChannels(channels []Channel) error { @@ -308,48 
+438,128 @@ func (channel *Channel) Delete() error { var channelStatusLock sync.Mutex -func UpdateChannelStatusById(id int, status int, reason string) bool { +// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling +var channelPollingLocks sync.Map + +// getChannelPollingLock returns or creates a mutex for the given channel ID +func getChannelPollingLock(channelId int) *sync.Mutex { + if lock, exists := channelPollingLocks.Load(channelId); exists { + return lock.(*sync.Mutex) + } + // Create new lock for this channel + newLock := &sync.Mutex{} + actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock) + return actual.(*sync.Mutex) +} + +// CleanupChannelPollingLocks removes locks for channels that no longer exist +// This is optional and can be called periodically to prevent memory leaks +func CleanupChannelPollingLocks() { + var activeChannelIds []int + DB.Model(&Channel{}).Pluck("id", &activeChannelIds) + + activeChannelSet := make(map[int]bool) + for _, id := range activeChannelIds { + activeChannelSet[id] = true + } + + channelPollingLocks.Range(func(key, value interface{}) bool { + channelId := key.(int) + if !activeChannelSet[channelId] { + channelPollingLocks.Delete(channelId) + } + return true + }) +} + +func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { + keys := channel.getKeys() + if len(keys) == 0 { + channel.Status = status + } else { + var keyIndex int + for i, key := range keys { + if key == usingKey { + keyIndex = i + break + } + } + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if status == common.ChannelStatusEnabled { + delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) + } else { + channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status + } + if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { + channel.Status = common.ChannelStatusAutoDisabled + info := 
channel.GetOtherInfo() + info["status_reason"] = "All keys are disabled" + info["status_time"] = common.GetTimestamp() + channel.SetOtherInfo(info) + } + } +} + +func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() - channelCache, _ := CacheGetChannel(id) - // 如果缓存渠道存在,且状态已是目标状态,直接返回 - if channelCache != nil && channelCache.Status == status { + channelCache, _ := CacheGetChannel(channelId) + if channelCache == nil { return false } - // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 - if channelCache == nil && status != common.ChannelStatusEnabled { - return false + if channelCache.ChannelInfo.IsMultiKey { + // 如果是多Key模式,更新缓存中的状态 + handlerMultiKeyUpdate(channelCache, usingKey, status) + //CacheUpdateChannel(channelCache) + //return true + } else { + // 如果缓存渠道存在,且状态已是目标状态,直接返回 + if channelCache.Status == status { + return false + } + // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 + if status != common.ChannelStatusEnabled { + return false + } + CacheUpdateChannelStatus(channelId, status) } - CacheUpdateChannelStatus(id, status) } - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + + shouldUpdateAbilities := false + defer func() { + if shouldUpdateAbilities { + err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) + if err != nil { + common.SysError("failed to update ability status: " + err.Error()) + } + } + }() + channel, err := GetChannelById(channelId, true) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) return false - } - channel, err := GetChannelById(id, true) - if err != nil { - // find channel by id error, directly update status - result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status) - if result.Error != nil { - common.SysError("failed to update channel status: " + result.Error.Error()) - return false - } - if result.RowsAffected == 0 { - return 
false - } } else { if channel.Status == status { return false } - // find channel by id success, update status and other info - info := channel.GetOtherInfo() - info["status_reason"] = reason - info["status_time"] = common.GetTimestamp() - channel.SetOtherInfo(info) - channel.Status = status + + if channel.ChannelInfo.IsMultiKey { + beforeStatus := channel.Status + handlerMultiKeyUpdate(channel, usingKey, status) + if beforeStatus != channel.Status { + shouldUpdateAbilities = true + } + } else { + info := channel.GetOtherInfo() + info["status_reason"] = reason + info["status_time"] = common.GetTimestamp() + channel.SetOtherInfo(info) + channel.Status = status + shouldUpdateAbilities = true + } err = channel.Save() if err != nil { common.SysError("failed to update channel status: " + err.Error()) @@ -532,6 +742,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { err := json.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) + channel.Setting = nil // 清空设置以避免后续错误 + _ = channel.Save() // 保存修改 } } return setting diff --git a/model/cache.go b/model/channel_cache.go similarity index 63% rename from model/cache.go rename to model/channel_cache.go index 3e5eb4c4..b2451248 100644 --- a/model/cache.go +++ b/model/channel_cache.go @@ -14,8 +14,8 @@ import ( "github.com/gin-gonic/gin" ) -var group2model2channels map[string]map[string][]*Channel -var channelsIDM map[int]*Channel +var group2model2channels map[string]map[string][]int // enabled channel +var channelsIDM map[int]*Channel // all channels include disabled var channelSyncLock sync.RWMutex func InitChannelCache() { @@ -24,7 +24,7 @@ func InitChannelCache() { } newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Find(&channels) for _, channel := range channels { newChannelId2channel[channel.Id] = channel } @@ -34,21 +34,22 @@ func 
InitChannelCache() { for _, ability := range abilities { groups[ability.Group] = true } - newGroup2model2channels := make(map[string]map[string][]*Channel) - newChannelsIDM := make(map[int]*Channel) + newGroup2model2channels := make(map[string]map[string][]int) for group := range groups { - newGroup2model2channels[group] = make(map[string][]*Channel) + newGroup2model2channels[group] = make(map[string][]int) } for _, channel := range channels { - newChannelsIDM[channel.Id] = channel + if channel.Status != common.ChannelStatusEnabled { + continue // skip disabled channels + } groups := strings.Split(channel.Group, ",") for _, group := range groups { models := strings.Split(channel.Models, ",") for _, model := range models { if _, ok := newGroup2model2channels[group][model]; !ok { - newGroup2model2channels[group][model] = make([]*Channel, 0) + newGroup2model2channels[group][model] = make([]int, 0) } - newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) + newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id) } } } @@ -57,7 +58,7 @@ func InitChannelCache() { for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { sort.Slice(channels, func(i, j int) bool { - return channels[i].GetPriority() > channels[j].GetPriority() + return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority() }) newGroup2model2channels[group][model] = channels } @@ -65,7 +66,7 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels - channelsIDM = newChannelsIDM + channelsIDM = newChannelId2channel channelSyncLock.Unlock() common.SysLog("channels synced from database") } @@ -128,16 +129,27 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, } channelSyncLock.RLock() + defer channelSyncLock.RUnlock() channels := group2model2channels[group][model] - 
channelSyncLock.RUnlock() if len(channels) == 0 { return nil, errors.New("channel not found") } + if len(channels) == 1 { + if channel, ok := channelsIDM[channels[0]]; ok { + return channel, nil + } + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]) + } + uniquePriorities := make(map[int]bool) - for _, channel := range channels { - uniquePriorities[int(channel.GetPriority())] = true + for _, channelId := range channels { + if channel, ok := channelsIDM[channelId]; ok { + uniquePriorities[int(channel.GetPriority())] = true + } else { + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) + } } var sortedUniquePriorities []int for priority := range uniquePriorities { @@ -152,9 +164,13 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, // get the priority for the given retry number var targetChannels []*Channel - for _, channel := range channels { - if channel.GetPriority() == targetPriority { - targetChannels = append(targetChannels, channel) + for _, channelId := range channels { + if channel, ok := channelsIDM[channelId]; ok { + if channel.GetPriority() == targetPriority { + targetChannels = append(targetChannels, channel) + } + } else { + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) } } @@ -188,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) { c, ok := channelsIDM[id] if !ok { - return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id)) + return nil, fmt.Errorf("渠道# %d,已不存在", id) + } + if c.Status != common.ChannelStatusEnabled { + return nil, fmt.Errorf("渠道# %d,已被禁用", id) } return c, nil } +func CacheGetChannelInfo(id int) (*ChannelInfo, error) { + if !common.MemoryCacheEnabled { + channel, err := GetChannelById(id, true) + if err != nil { + return nil, err + } + return &channel.ChannelInfo, nil + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + + c, ok := channelsIDM[id] + if !ok { + return nil, fmt.Errorf("渠道# %d,已不存在", id) + } + if c.Status != 
common.ChannelStatusEnabled { + return nil, fmt.Errorf("渠道# %d,已被禁用", id) + } + return &c.ChannelInfo, nil +} + func CacheUpdateChannelStatus(id int, status int) { if !common.MemoryCacheEnabled { return @@ -203,3 +243,20 @@ func CacheUpdateChannelStatus(id int, status int) { channel.Status = status } } + +func CacheUpdateChannel(channel *Channel) { + if !common.MemoryCacheEnabled { + return + } + channelSyncLock.Lock() + defer channelSyncLock.Unlock() + if channel == nil { + return + } + + println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex) + + println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) + channelsIDM[channel.Id] = channel + println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) +} diff --git a/model/log.go b/model/log.go index e2d1ee5a..45923075 100644 --- a/model/log.go +++ b/model/log.go @@ -49,7 +49,7 @@ func formatUserLogs(logs []*Log) { for i := range logs { logs[i].ChannelName = "" var otherMap map[string]interface{} - otherMap = common.StrToMap(logs[i].Other) + otherMap, _ = common.StrToMap(logs[i].Other) if otherMap != nil { // delete admin delete(otherMap, "admin_info") diff --git a/model/main.go b/model/main.go index d46a21cf..e2f9aecb 100644 --- a/model/main.go +++ b/model/main.go @@ -57,7 +57,7 @@ func initCol() { } } // log sql type and database type - common.SysLog("Using Log SQL Type: " + common.LogSqlType) + //common.SysLog("Using Log SQL Type: " + common.LogSqlType) } var DB *gorm.DB @@ -225,12 +225,6 @@ func InitLogDB() (err error) { if !common.IsMasterNode { return nil } - //if common.UsingMySQL { - // _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded - // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded - // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // 
TODO: delete this line when most users have upgraded - // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded - //} common.SysLog("database migration started") err = migrateLOGDB() return err diff --git a/model/user_cache.go b/model/user_cache.go index a62d9773..a631457c 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -1,7 +1,6 @@ package model import ( - "encoding/json" "fmt" "one-api/common" "one-api/constant" @@ -36,7 +35,7 @@ func (user *UserBase) WriteContext(c *gin.Context) { func (user *UserBase) GetSetting() dto.UserSetting { setting := dto.UserSetting{} if user.Setting != "" { - err := json.Unmarshal([]byte(user.Setting), &setting) + err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) } diff --git a/relay/audio_handler.go b/relay/audio_handler.go index c1ce1a02..f39dbd82 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -3,7 +3,6 @@ package relay import ( "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/dto" @@ -12,7 +11,10 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { @@ -54,13 +56,13 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. 
return audioRequest, nil } -func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } promptTokens := 0 @@ -73,7 +75,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -88,23 +90,23 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } resp, err := adaptor.DoRequest(c, relayInfo, 
ioReader) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -112,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 2ff34e01..ab8836ba 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -21,7 +22,7 @@ type Adaptor interface { ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) + 
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) GetModelList() []string GetChannelName() string ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 63525cc4..d941a1bc 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -99,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeImagesGenerations: err, usage = aliImageHandler(c, resp, info) @@ -109,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = RerankHandler(c, resp, info) default: if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index c84c7885..0d430c62 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -4,15 +4,17 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/types" "strings" "time" + + "github.com/gin-gonic/gin" ) func 
oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { @@ -124,49 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc return &imageResponse } -func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseFormat := c.GetString("response_format") var aliTaskResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if aliTaskResponse.Message != "" { common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) - return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId) if err != nil { - return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponse), nil } if aliResponse.Output.TaskStatus != "SUCCEEDED" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: aliResponse.Output.Message, - Type: "ali_error", - Param: "", - Code: aliResponse.Output.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return types.WithOpenAIError(types.OpenAIError{ + Message: 
aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, resp.StatusCode), nil } fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, nil + c.Writer.Write(jsonResponse) + return nil, &dto.Usage{} } diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index ebfe26de..59cb0a11 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -7,7 +7,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/service" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -31,29 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest { } } -func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil } common.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if aliResponse.Code != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: 
dto.OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return types.WithOpenAIError(types.OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, resp.StatusCode), nil } usage := dto.Usage{ @@ -68,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI jsonResponse, err := json.Marshal(rerankResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil - } - + c.Writer.Write(jsonResponse) return nil, &usage } diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 149c9b4b..bc49501c 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -8,9 +8,10 @@ import ( "one-api/common" "one-api/dto" "one-api/relay/helper" - "one-api/service" "strings" + "one-api/types" + "github.com/gin-gonic/gin" ) @@ -38,11 +39,11 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque } } -func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var fullTextResponse dto.OpenAIEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } 
common.CloseResponseBodyGracefully(resp) @@ -53,11 +54,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } @@ -119,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre return &response } -func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var usage dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) @@ -174,32 +175,29 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith return nil, &usage } -func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if aliResponse.Code != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: 
aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return types.WithOpenAIError(types.OpenAIError{ + Message: aliResponse.Message, + Type: "ali_error", + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, resp.StatusCode), nil } fullTextResponse := responseAli2OpenAI(&aliResponse) - jsonResponse, err := json.Marshal(fullTextResponse) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 9c879399..d3354f00 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/relay/channel/claude" relaycommon "one-api/relay/common" "one-api/setting/model_setting" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -84,7 +85,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return nil, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { err, usage = awsStreamHandler(c, resp, info, a.RequestMode) } else { diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 3c9542c6..0df19e07 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -3,19 +3,22 @@ package aws import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkg/errors" "net/http" "one-api/common" "one-api/dto" "one-api/relay/channel/claude" 
relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/types" "strings" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" - "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" ) func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) { @@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string { return modelPrefix + "." + awsModelId } -func awsModelID(requestModel string) (string, error) { +func awsModelID(requestModel string) string { if awsModelID, ok := awsModelIDMap[requestModel]; ok { - return awsModelID, nil + return awsModelID } - return requestModel, nil + return requestModel } -func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { - return wrapErr(errors.Wrap(err, "newAwsClient")), nil + return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil } - awsModelId, err := awsModelID(c.GetString("request_model")) - if err != nil { - return wrapErr(errors.Wrap(err, "awsModelID")), nil - } + awsModelId := awsModelID(c.GetString("request_model")) awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) @@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* claudeReq_, ok := c.Get("converted_request") if !ok { - return wrapErr(errors.New("request not found")), nil + return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil } claudeReq := 
claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil } awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModel")), nil + return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil } claudeInfo := &claude.ClaudeResponseInfo{ - ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, Usage: &dto.Usage{}, } - claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage) + handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage) + if handlerErr != nil { + return handlerErr, nil + } return nil, claudeInfo.Usage } -func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { - return wrapErr(errors.Wrap(err, "newAwsClient")), nil + return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil } - awsModelId, err := awsModelID(c.GetString("request_model")) - if err != nil { - return wrapErr(errors.Wrap(err, "awsModelID")), nil - } + awsModelId := awsModelID(c.GetString("request_model")) awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) @@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel claudeReq_, ok := 
c.Get("converted_request") if !ok { - return wrapErr(errors.New("request not found")), nil + return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil } claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil } awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil } stream := awsResp.GetStream() defer stream.Close() claudeInfo := &claude.ClaudeResponseInfo{ - ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, @@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel for event := range stream.Events() { switch v := event.(type) { - case *types.ResponseStreamMemberChunk: + case *bedrockruntimeTypes.ResponseStreamMemberChunk: info.SetFirstResponseTime() respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage) if respErr != nil { return respErr, nil } - case *types.UnknownUnionMember: + case *bedrockruntimeTypes.UnknownUnionMember: fmt.Println("unknown tag:", v.Tag) - return wrapErr(errors.New("unknown response type")), nil + return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil default: fmt.Println("union is nil or unknown type") - return wrapErr(errors.New("nil or unknown response type")), nil + return types.NewError(errors.New("nil or unknown response type"), 
types.ErrorCodeInvalidRequest), nil } } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 396c31ab..22443354 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -140,15 +141,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = baiduStreamHandler(c, resp) + err, usage = baiduStreamHandler(c, info, resp) } else { switch info.RelayMode { case constant.RelayModeEmbeddings: - err, usage = baiduEmbeddingHandler(c, resp) + err, usage = baiduEmbeddingHandler(c, info, resp) default: - err, usage = baiduHandler(c, resp) + err, usage = baiduHandler(c, info, resp) } } return diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 11492fe3..06b48c20 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -1,21 +1,23 @@ package baidu import ( - "bufio" "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 @@ -110,92 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI return &openAIEmbeddingResponse } -func baiduStreamHandler(c *gin.Context, resp *http.Response) 
(*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var usage dto.Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - data = data[6:] - dataChan <- data - } - stopChan <- true - }() - helper.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var baiduResponse BaiduChatStreamResponse - err := json.Unmarshal([]byte(data), &baiduResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if baiduResponse.Usage.TotalTokens != 0 { - usage.TotalTokens = baiduResponse.Usage.TotalTokens - usage.PromptTokens = baiduResponse.Usage.PromptTokens - usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens - } - response := streamResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) +func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { + usage := &dto.Usage{} + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + var baiduResponse BaiduChatStreamResponse + err := common.Unmarshal([]byte(data), &baiduResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) return true - case <-stopChan: - c.Render(-1, 
common.CustomEvent{Data: "data: [DONE]"}) - return false } + if baiduResponse.Usage.TotalTokens != 0 { + usage.TotalTokens = baiduResponse.Usage.TotalTokens + usage.PromptTokens = baiduResponse.Usage.PromptTokens + usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens + } + response := streamResponseBaidu2OpenAI(&baiduResponse) + err = helper.ObjectData(c, response) + if err != nil { + common.SysError("error sending stream response: " + err.Error()) + } + return true }) common.CloseResponseBodyGracefully(resp) - return nil, &usage + return nil, usage } -func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var baiduResponse BaiduChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if baiduResponse.ErrorMsg != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: baiduResponse.ErrorMsg, - Type: "baidu_error", - Param: "", - Code: baiduResponse.ErrorCode, - }, - StatusCode: resp.StatusCode, - }, nil + return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil } fullTextResponse := responseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, 
types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -203,32 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat return nil, &fullTextResponse.Usage } -func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var baiduResponse BaiduEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if baiduResponse.ErrorMsg != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: baiduResponse.ErrorMsg, - Type: "baidu_error", - Param: "", - Code: baiduResponse.ErrorCode, - }, - StatusCode: resp.StatusCode, - }, nil + return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil } fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 470f2a0c..375fd531 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ 
b/relay/channel/baidu_v2/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -92,11 +93,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 8389b9f1..540742d6 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -94,7 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index a8607d86..d03c61c2 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -12,6 +12,7 @@ 
import ( "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -125,7 +126,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla if textRequest.Reasoning != nil { var reasoning openrouter.RequestReasoning - if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil { + if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil { return nil, err } @@ -517,22 +518,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons return true } -func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode { +func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError { var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Code: "stream_response_error", - Type: claudeResponse.Error.Type, - Message: claudeResponse.Error.Message, - }, - StatusCode: http.StatusInternalServerError, - } + return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) } if info.RelayFormat == relaycommon.RelayFormatClaude { FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) @@ -593,15 +587,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } } -func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, 
requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { claudeInfo := &ClaudeResponseInfo{ - ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, Usage: &dto.Usage{}, } - var err *dto.OpenAIErrorWithStatusCode + var err *types.NewAPIError helper.StreamScannerHandler(c, resp, info, func(data string) bool { err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode) if err != nil { @@ -617,21 +611,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return nil, claudeInfo.Usage } -func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode { +func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError { var claudeResponse dto.ClaudeResponse - err := common.UnmarshalJson(data, &claudeResponse) + err := common.Unmarshal(data, &claudeResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: claudeResponse.Error.Message, - Type: claudeResponse.Error.Type, - Code: claudeResponse.Error.Type, - }, - StatusCode: http.StatusInternalServerError, - } + return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) } if requestMode == RequestModeCompletion { completionTokens := service.CountTextToken(claudeResponse.Completion, 
info.OriginModelName) @@ -652,7 +639,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud openaiResponse.Usage = *claudeInfo.Usage responseData, err = json.Marshal(openaiResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeBadResponseBody) } case relaycommon.RelayFormatClaude: responseData = data @@ -662,11 +649,11 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud return nil } -func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { defer common.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ - ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), Model: info.UpstreamModelName, ResponseText: strings.Builder{}, @@ -674,7 +661,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } if common.DebugEnabled { println("responseBody: ", string(responseBody)) diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 06f4ca34..6e59ad71 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -94,20 +95,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info 
*relaycommon.RelayInf return nil, errors.New("not implemented") } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeEmbeddings: fallthrough case constant.RelayModeChatCompletions: if info.IsStream { - err, usage = cfStreamHandler(c, resp, info) + err, usage = cfStreamHandler(c, info, resp) } else { - err, usage = cfHandler(c, resp, info) + err, usage = cfHandler(c, info, resp) } case constant.RelayModeAudioTranslation: fallthrough case constant.RelayModeAudioTranscription: - err, usage = cfSTTHandler(c, resp, info) + err, usage = cfSTTHandler(c, info, resp) } return } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 1c3a26f7..5e8fe7f9 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -3,7 +3,6 @@ package cloudflare import ( "bufio" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -11,8 +10,11 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "time" + + "github.com/gin-gonic/gin" ) func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { @@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque } } -func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) @@ -86,16 +88,16 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info 
*relaycommon.Rela return nil, usage } -func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } common.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } response.Model = info.UpstreamModelName var responseText string @@ -107,7 +109,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -115,16 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) return nil, usage } -func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { var cfResp CfAudioResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), 
nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } audioResp := &dto.AudioResponse{ @@ -133,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn jsonResponse, err := json.Marshal(audioResp) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeBadResponseBody), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index a93b10f6..4f3a96c3 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -71,14 +72,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeRerank { - err, usage = cohereRerankHandler(c, resp, info) + usage, err = cohereRerankHandler(c, resp, info) } else { if info.IsStream { - err, usage = cohereStreamHandler(c, resp, info) + usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this } else { - err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + usage, err = cohereHandler(c, info, resp) } } return diff --git 
a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 4637740d..fcfb12b7 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -3,7 +3,6 @@ package cohere import ( "bufio" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -11,8 +10,11 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "time" + + "github.com/gin-gonic/gin" ) func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { @@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string { } } -func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseId := helper.GetResponseID(c) createdTime := common.GetTimestamp() usage := &dto.Usage{} @@ -164,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
if usage.PromptTokens == 0 { usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } - return nil, usage + return usage, nil } -func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { createdTime := common.GetTimestamp() responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } usage := dto.Usage{} usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens @@ -188,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt openaiResp.Id = cohereResp.ResponseId openaiResp.Created = createdTime openaiResp.Object = "chat.completion" - openaiResp.Model = modelName + openaiResp.Model = info.UpstreamModelName openaiResp.Usage = usage openaiResp.Choices = []dto.OpenAITextResponseChoice{ @@ -201,24 +203,24 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt jsonResponse, err := json.Marshal(openaiResp) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage + _, _ = 
c.Writer.Write(jsonResponse) + return &usage, nil } -func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } usage := dto.Usage{} if cohereResp.Meta.BilledUnits.InputTokens == 0 { @@ -237,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. jsonResponse, err := json.Marshal(rerankResp) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) - return nil, &usage + return &usage, nil } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 80441a51..fe5f5f00 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/dto" "one-api/relay/channel" "one-api/relay/common" + "one-api/types" "time" "github.com/gin-gonic/gin" @@ -95,11 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody } // DoResponse implements channel.Adaptor. 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = cozeChatStreamHandler(c, resp, info) + usage, err = cozeChatStreamHandler(c, info, resp) } else { - err, usage = cozeChatHandler(c, resp, info) + usage, err = cozeChatHandler(c, info, resp) } return } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 618fe16f..32cc6937 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -12,6 +12,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -43,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C return cozeRequest } -func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) // convert coze response to openai response @@ -55,10 +56,10 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela response.Model = info.UpstreamModelName err = json.Unmarshal(responseBody, &cozeResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if cozeResponse.Code != 0 { - return 
service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil + return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody) } // 从上下文获取 usage var usage dto.Usage @@ -85,16 +86,16 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } jsonResponse, err := json.Marshal(response) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, _ = c.Writer.Write(jsonResponse) - return nil, &usage + return &usage, nil } -func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) helper.SetEventStreamHeaders(c) @@ -135,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo } if err := scanner.Err(); err != nil { - return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } helper.Done(c) @@ -143,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) } - return nil, usage + return usage, nil } func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 76e7fa8d..edfc7fd3 100644 --- 
a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -81,11 +82,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 51dbee71..4ad16766 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -96,11 +97,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = difyStreamHandler(c, resp, info) + return difyStreamHandler(c, info, resp) } else { - err, usage = difyHandler(c, resp, info) + return difyHandler(c, info, resp) } return } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 
3a2845b3..47337127 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -14,6 +14,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "os" "strings" @@ -209,7 +210,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt return &response } -func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var responseText string usage := &dto.Usage{} var nodeToken int @@ -247,20 +248,20 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } usage.CompletionTokens += nodeToken - return nil, usage + return usage, nil } -func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var difyResponse DifyChatCompletionResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } fullTextResponse := dto.OpenAITextResponse{ Id: difyResponse.ConversationId, @@ -279,10 +280,10 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf fullTextResponse.Choices = 
append(fullTextResponse.Choices, choice) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &difyResponse.MetaData.Usage + c.Writer.Write(jsonResponse) + return &difyResponse.MetaData.Usage, nil } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 968d9c9b..71eb9ba4 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -11,8 +11,8 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" - "one-api/service" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -168,30 +168,30 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { if info.IsStream { - return GeminiTextGenerationStreamHandler(c, resp, info) + return GeminiTextGenerationStreamHandler(c, info, resp) } else { - return GeminiTextGenerationHandler(c, resp, info) + return GeminiTextGenerationHandler(c, info, resp) } } if strings.HasPrefix(info.UpstreamModelName, "imagen") { - return GeminiImageHandler(c, resp, info) + return GeminiImageHandler(c, info, resp) } // check if the model is an embedding model if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") || 
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { - return GeminiEmbeddingHandler(c, resp, info) + return GeminiEmbeddingHandler(c, info, resp) } if info.IsStream { - err, usage = GeminiChatStreamHandler(c, resp, info) + return GeminiChatStreamHandler(c, info, resp) } else { - err, usage = GeminiChatHandler(c, resp, info) + return GeminiChatHandler(c, info, resp) } //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 { @@ -205,23 +205,23 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom // } //} - return + return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) } -func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) } _ = resp.Body.Close() var geminiResponse GeminiImageResponse if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) } if len(geminiResponse.Predictions) == 0 { - return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest) + return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody) } // convert to openai format response @@ -241,7 +241,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R jsonResponse, jsonErr := json.Marshal(openAIResponse) if jsonErr != nil { - return nil, 
service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") @@ -253,7 +253,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R const imageTokens = 258 generatedImages := len(openAIResponse.Data) - usage = &dto.Usage{ + usage := &dto.Usage{ PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens CompletionTokens: 0, // image generation does not calculate completion tokens TotalTokens: imageTokens * generatedImages, diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 52846c66..0870e3fa 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -8,18 +8,19 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" ) -func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { +func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if common.DebugEnabled { @@ -28,9 +29,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela // 解析为 Gemini 原生响应格式 var geminiResponse GeminiChatResponse - err = common.UnmarshalJson(responseBody, &geminiResponse) + err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, service.OpenAIErrorWrapper(err, 
"unmarshal_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } // 计算使用量(基于 UsageMetadata) @@ -51,9 +52,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 直接返回 Gemini 原生格式的 JSON 响应 - jsonResponse, err := common.EncodeJson(geminiResponse) + jsonResponse, err := common.Marshal(geminiResponse) if err != nil { - return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.IOCopyBytesGracefully(c, resp, jsonResponse) @@ -61,7 +62,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela return &usage, nil } -func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { +func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var usage = &dto.Usage{} var imageCount int diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 1544e8cf..6f3babeb 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -2,6 +2,7 @@ package gemini import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -12,6 +13,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" + "one-api/types" "strconv" "strings" "unicode/utf8" @@ -792,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C return &response, isStop, hasImage } -func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { // responseText := "" id := 
helper.GetResponseID(c) createAt := common.GetTimestamp() @@ -858,33 +860,25 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom } helper.Done(c) //resp.Body.Close() - return nil, usage + return usage, nil } -func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } var geminiResponse GeminiChatResponse - err = common.UnmarshalJson(responseBody, &geminiResponse) + err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if len(geminiResponse.Candidates) == 0 { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: "No candidates returned", - Type: "server_error", - Param: "", - Code: 500, - }, - StatusCode: resp.StatusCode, - }, nil + return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -908,25 +902,25 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, 
types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage + c.Writer.Write(jsonResponse) + return &usage, nil } -func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) + return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) } var geminiResponse GeminiEmbeddingResponse - if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError) + if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) } // convert to openai format response @@ -947,16 +941,16 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm // Google has not yet clarified how embedding models will be billed // refer to openai billing method to use input tokens billing // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings - usage = &dto.Usage{ + usage := &dto.Usage{ PromptTokens: info.PromptTokens, CompletionTokens: 0, TotalTokens: info.PromptTokens, } - openAIResponse.Usage = *usage.(*dto.Usage) + openAIResponse.Usage = *usage - jsonResponse, jsonErr := common.EncodeJson(openAIResponse) + jsonResponse, jsonErr := common.Marshal(openAIResponse) if jsonErr != nil { - return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", 
http.StatusInternalServerError) + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) } common.IOCopyBytesGracefully(c, resp, jsonResponse) diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 85b6a83f..408a5c6e 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -11,6 +11,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/common_handler" "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -73,11 +74,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeRerank { - err, usage = common_handler.RerankHandler(c, info, resp) + usage, err = common_handler.RerankHandler(c, info, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 44f57e61..434a1031 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -69,11 +70,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage 
any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index b889f225..b0b54b0c 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -84,11 +85,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeEmbeddings: - err, usage = mokaEmbeddingHandler(c, resp) + return mokaEmbeddingHandler(c, info, resp) default: // err, usage = mokaHandler(c, resp) diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 645475dd..78f96d6d 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -2,12 +2,14 @@ package mokaai import ( "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" - "one-api/service" + relaycommon "one-api/relay/common" + "one-api/types" + + "github.com/gin-gonic/gin" ) func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest { @@ -48,16 +50,16 @@ func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEm return &openAIEmbeddingResponse } -func mokaEmbeddingHandler(c *gin.Context, resp 
*http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var baiduResponse dto.EmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } // if baiduResponse.ErrorMsg != "" { // return &dto.OpenAIErrorWithStatusCode{ @@ -69,12 +71,12 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError // }, nil // } fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(fullTextResponse) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + common.IOCopyBytesGracefully(c, resp, jsonResponse) + return &fullTextResponse.Usage, nil } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 18069311..b9e304fc 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -74,14 +75,14 @@ func (a *Adaptor) DoRequest(c 
*gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { - err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + usage, err = ollamaEmbeddingHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index bf7501e5..295349e3 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,15 +1,17 @@ package ollama import ( - "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/service" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) { @@ -82,19 +84,19 @@ func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequ } } -func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var ollamaEmbeddingResponse OllamaEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", 
http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) - err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse) + err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if ollamaEmbeddingResponse.Error != "" { - return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil + return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody) } flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) @@ -103,22 +105,22 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in Object: "embedding", }) usage := &dto.Usage{ - TotalTokens: promptTokens, + TotalTokens: info.PromptTokens, CompletionTokens: 0, - PromptTokens: promptTokens, + PromptTokens: info.PromptTokens, } embeddingResponse := &dto.OpenAIEmbeddingResponse{ Object: "list", Data: data, - Model: model, + Model: info.UpstreamModelName, Usage: *usage, } - doResponseBody, err := json.Marshal(embeddingResponse) + doResponseBody, err := common.Marshal(embeddingResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.IOCopyBytesGracefully(c, resp, doResponseBody) - return nil, usage + return usage, nil } func flattenEmbeddings(embeddings [][]float64) []float64 { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 367dbc47..217790a7 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -22,6 +22,7 @@ import ( "one-api/relay/common_handler" 
relayconstant "one-api/relay/constant" "one-api/service" + "one-api/types" "path/filepath" "strings" @@ -421,31 +422,31 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case relayconstant.RelayModeRealtime: err, usage = OpenaiRealtimeHandler(c, info) case relayconstant.RelayModeAudioSpeech: - err, usage = OpenaiTTSHandler(c, resp, info) + usage = OpenaiTTSHandler(c, resp, info) case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - err, usage = OpenaiHandlerWithUsage(c, resp, info) + usage, err = OpenaiHandlerWithUsage(c, info, resp) case relayconstant.RelayModeRerank: - err, usage = common_handler.RerankHandler(c, info, resp) + usage, err = common_handler.RerankHandler(c, info, resp) case relayconstant.RelayModeResponses: if info.IsStream { - err, usage = OaiResponsesStreamHandler(c, resp, info) + usage, err = OaiResponsesStreamHandler(c, info, resp) } else { - err, usage = OaiResponsesHandler(c, resp, info) + usage, err = OaiResponsesHandler(c, info, resp) } default: if info.IsStream { - err, usage = OaiStreamHandler(c, resp, info) + usage, err = OaiStreamHandler(c, info, resp) } else { - err, usage = OpenaiHandler(c, resp, info) + usage, err = OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 6aa73274..bfe8bcd3 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -17,6 +17,8 @@ import ( "path/filepath" 
"strings" + "one-api/types" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -104,10 +106,10 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo return helper.ObjectData(c, lastStreamResponse) } -func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } defer common.CloseResponseBodyGracefully(resp) @@ -177,26 +179,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) - return nil, usage + return usage, nil } -func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } - err = common.UnmarshalJson(responseBody, &simpleResponse) + err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return 
nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: *simpleResponse.Error, - StatusCode: resp.StatusCode, - }, nil + return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) } forceFormat := false @@ -220,28 +219,28 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: if forceFormat { - responseBody, err = common.EncodeJson(simpleResponse) + responseBody, err = common.Marshal(simpleResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } } else { break } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) - claudeRespStr, err := common.EncodeJson(claudeResp) + claudeRespStr, err := common.Marshal(claudeResp) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr } common.IOCopyBytesGracefully(c, resp, responseBody) - return nil, &simpleResponse.Usage + return &simpleResponse.Usage, nil } -func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage { // the status code has been judged before, if there is a body reading failure, // it should be regarded as a non-recoverable error, so it should not return err for external retry. 
// Analogous to nginx's load balancing, it will only retry if it can't be requested or @@ -261,20 +260,20 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { common.LogError(c, err.Error()) } - return nil, usage + return usage } -func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { defer common.CloseResponseBodyGracefully(resp) // count tokens by audio file duration audioTokens, err := countAudioTokens(c) if err != nil { - return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeCountTokenFailed), nil } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil } // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) @@ -328,9 +327,9 @@ func countAudioTokens(c *gin.Context) (int, error) { return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens } -func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) { +func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) { if info == nil || info.ClientWs == nil || info.TargetWs == nil { - return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil + return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil } info.IsStream = true @@ -368,7 +367,7 @@ func OpenaiRealtimeHandler(c 
*gin.Context, info *relaycommon.RelayInfo) (*dto.Op } realtimeEvent := &dto.RealtimeEvent{} - err = common.UnmarshalJson(message, realtimeEvent) + err = common.Unmarshal(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -428,7 +427,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } info.SetFirstResponseTime() realtimeEvent := &dto.RealtimeEvent{} - err = common.UnmarshalJson(message, realtimeEvent) + err = common.Unmarshal(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -553,18 +552,18 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R return err } -func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } var usageResp dto.SimpleResponse - err = common.UnmarshalJson(responseBody, &usageResp) + err = common.Unmarshal(responseBody, &usageResp) if err != nil { - return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } // 写入新的 response body @@ -584,5 +583,5 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens } - return nil, &usageResp.Usage + return &usageResp.Usage, nil 
} diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 7f426c33..e874f375 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -9,33 +9,27 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" ) -func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) // read response body var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } - err = common.UnmarshalJson(responseBody, &responsesResponse) + err = common.Unmarshal(responseBody, &responsesResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if responsesResponse.Error != nil { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: responsesResponse.Error.Message, - Type: "openai_error", - Code: responsesResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) } // 写入新的 response body @@ -50,13 +44,13 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
for _, tool := range responsesResponse.Tools { info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++ } - return nil, &usage + return &usage, nil } -func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } var usage = &dto.Usage{} @@ -99,5 +93,5 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc } } - return nil, usage + return usage, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index aee4a307..a60dc4b2 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -70,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + usage, err = palmHandler(c, info, resp) } return } 
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 44c60713..4db31573 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -2,14 +2,17 @@ package palm import ( "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" + + "github.com/gin-gonic/gin" ) // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body @@ -70,7 +73,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti return &response } -func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { +func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) { responseText := "" responseId := helper.GetResponseID(c) createdTime := common.GetTimestamp() @@ -121,42 +124,39 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit return nil, responseText } -func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 
{ - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: palmResponse.Error.Message, - Type: palmResponse.Error.Status, - Param: "", - Code: palmResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return nil, types.WithOpenAIError(types.OpenAIError{ + Message: palmResponse.Error.Message, + Type: palmResponse.Error.Status, + Param: "", + Code: palmResponse.Error.Code, + }, resp.StatusCode) } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model) + completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName) usage := dto.Usage{ - PromptTokens: promptTokens, + PromptTokens: info.PromptTokens, CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + TotalTokens: info.PromptTokens + completionTokens, } fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage + common.IOCopyBytesGracefully(c, resp, jsonResponse) + return &usage, nil } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index ca206503..19830aca 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -73,11 +74,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 89236ea3..63c1c84d 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -76,20 +77,20 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeRerank: - err, usage = siliconflowRerankHandler(c, resp) + usage, err = siliconflowRerankHandler(c, info, resp) case constant.RelayModeCompletions: fallthrough case constant.RelayModeChatCompletions: if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } diff --git a/relay/channel/siliconflow/relay-siliconflow.go 
b/relay/channel/siliconflow/relay-siliconflow.go index a52ebfda..fabaf9c6 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -2,24 +2,26 @@ package siliconflow import ( "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" - "one-api/service" + relaycommon "one-api/relay/common" + "one-api/types" + + "github.com/gin-gonic/gin" ) -func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } common.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } usage := &dto.Usage{ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, @@ -33,10 +35,10 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE jsonResponse, err := json.Marshal(rerankResp) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, usage + common.IOCopyBytesGracefully(c, resp, jsonResponse) + return usage, nil } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 
7ea3aae7..520276a7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -6,10 +6,11 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" - "one-api/service" + "one-api/types" "strconv" "strings" @@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - apiKey := c.Request.Header.Get("Authorization") + apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey) apiKey = strings.TrimPrefix(apiKey, "Bearer ") appId, secretId, secretKey, err := parseTencentConfig(apiKey) a.AppID = appId @@ -94,13 +95,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - var responseText string - err, responseText = tencentStreamHandler(c, resp) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, err = tencentStreamHandler(c, info, resp) } else { - err, usage = tencentHandler(c, resp) + usage, err = tencentHandler(c, info, resp) } return } diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index a7106a88..c3d96c49 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -8,17 +8,20 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strconv" "strings" "time" 
+ + "github.com/gin-gonic/gin" ) // https://cloud.tencent.com/document/product/1729/97732 @@ -86,7 +89,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha return &response } -func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { +func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var responseText string scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) @@ -126,38 +129,35 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError common.CloseResponseBodyGracefully(resp) - return nil, responseText + return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil } -func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var tencentSb TencentChatResponseSB responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if tencentSb.Response.Error.Code != 0 { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: tencentSb.Response.Error.Message, - Code: tencentSb.Response.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return nil, types.WithOpenAIError(types.OpenAIError{ + Message: tencentSb.Response.Error.Message, + Code: tencentSb.Response.Error.Code, + }, resp.StatusCode) } 
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response) - jsonResponse, err := json.Marshal(fullTextResponse) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + common.IOCopyBytesGracefully(c, resp, jsonResponse) + return &fullTextResponse.Usage, nil } func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index e568f651..fa895de0 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -14,6 +14,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -208,19 +209,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { switch a.RequestMode { case RequestModeClaude: err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info) + usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp) } else { - err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + usage, err = 
gemini.GeminiChatStreamHandler(c, info, resp) } case RequestModeLlama: - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } } else { switch a.RequestMode { @@ -228,12 +229,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationHandler(c, resp, info) + usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) } else { - err, usage = gemini.GeminiChatHandler(c, resp, info) + usage, err = gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go index d2596320..5ed87665 100644 --- a/relay/channel/vertex/relay-vertex.go +++ b/relay/channel/vertex/relay-vertex.go @@ -4,8 +4,11 @@ import "one-api/common" func GetModelRegion(other string, localModelName string) string { // if other is json string - if common.IsJsonStr(other) { - m := common.StrToMap(other) + if common.IsJsonObject(other) { + m, err := common.StrToMap(other) + if err != nil { + return other // return original if parsing fails + } if m[localModelName] != nil { return m[localModelName].(string) } else { diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 78233934..af15d636 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -13,6 +13,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/types" "path/filepath" "strings" @@ -225,18 +226,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a 
*Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeChatCompletions: if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } case constant.RelayModeEmbeddings: - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) } return } diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 17747dd5..8d880137 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/types" "strings" "one-api/relay/constant" @@ -104,15 +105,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) default: if info.IsStream { - err, usage = xAIStreamHandler(c, resp, info) + usage, 
err = xAIStreamHandler(c, info, resp) } else { - err, usage = xAIHandler(c, resp, info) + usage, err = xAIHandler(c, info, resp) } } return diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4a030e48..272cc749 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -10,6 +10,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -34,7 +35,7 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage return openAIResp } -func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { usage := &dto.Usage{} var responseTextBuilder strings.Builder var toolCount int @@ -74,30 +75,28 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel helper.Done(c) common.CloseResponseBodyGracefully(resp) - return nil, usage + return usage, nil } -func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { defer common.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) var response *dto.SimpleResponse - err = common.UnmarshalJson(responseBody, &response) + err = common.Unmarshal(responseBody, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return nil, nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens 
// new body - encodeJson, err := common.EncodeJson(response) + encodeJson, err := common.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return nil, nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.IOCopyBytesGracefully(c, resp, encodeJson) - return nil, &response.Usage + return &response.Usage, nil } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 7591e0e7..0d218ada 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -7,7 +7,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" - "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -74,18 +74,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { splits := strings.Split(info.ApiKey, "|") if len(splits) != 3 { - return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey) } if a.request == nil { - return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) + return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest) } if info.IsStream { - err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) + usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) } else { - err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2]) + usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2]) } 
return } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index c6ef722c..373ad605 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -6,18 +6,18 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" - "net/http" "net/url" "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/helper" - "one-api/service" + "one-api/types" "strings" "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) // https://console.xfyun.cn/services/cbm @@ -126,11 +126,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeDoRequestFailed) } helper.SetEventStreamHeaders(c) var usage dto.Usage @@ -153,14 +153,14 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a return false } }) - return nil, &usage + return &usage, nil } -func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, 
apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeDoRequestFailed) } var usage dto.Usage var content string @@ -191,11 +191,11 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s response := responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") _, _ = c.Writer.Write(jsonResponse) - return nil, &usage + return &usage, nil } func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index b4d8fb30..43344428 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -77,11 +78,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo return nil, errors.New("not implemented") } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = zhipuStreamHandler(c, resp) + usage, err = zhipuStreamHandler(c, info, resp) } else { - err, usage = zhipuHandler(c, resp) + usage, err = zhipuHandler(c, info, resp) } return } diff --git 
a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 91cd384b..916a200d 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -3,18 +3,20 @@ package zhipu import ( "bufio" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" "io" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/relay/helper" - "one-api/service" + "one-api/types" "strings" "sync" "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" ) // https://open.bigmodel.cn/doc/api#chatglm_std @@ -150,7 +152,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt return &response, &zhipuResponse.Usage } -func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) @@ -211,38 +213,33 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } }) common.CloseResponseBodyGracefully(resp) - return nil, usage + return usage, nil } -func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, 
types.NewError(err, types.ErrorCodeBadResponseBody) } if !zhipuResponse.Success { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: zhipuResponse.Msg, - Type: "zhipu_error", - Param: "", - Code: zhipuResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil + return nil, types.WithOpenAIError(types.OpenAIError{ + Message: zhipuResponse.Msg, + Code: zhipuResponse.Code, + }, resp.StatusCode) } fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + return &fullTextResponse.Usage, nil } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 222cdff8..edd7a534 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -80,11 +81,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) + usage, err = openai.OaiStreamHandler(c, info, resp) } else { - err, usage = openai.OpenaiHandler(c, resp, info) + usage, err = openai.OpenaiHandler(c, info, resp) } return } 
diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 42139ddf..da0edfd9 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -2,10 +2,8 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -14,7 +12,10 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { @@ -32,14 +33,14 @@ func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest return textRequest, nil } -func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { +func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoClaude(c) // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateClaudeRequest(c) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } if textRequest.Stream { @@ -48,35 +49,35 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeCountTokenFailed) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "model_price_error", 
http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return service.OpenAIErrorToClaudeError(openaiErr) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) var requestBody io.Reader @@ -109,14 +110,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if common.DebugEnabled { println("requestBody: ", string(jsonData)) } if err != nil { - return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } requestBody = bytes.NewBuffer(jsonData) @@ -124,26 +125,26 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.ClaudeErrorWrapper(err, 
"do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return service.OpenAIErrorToClaudeError(openaiErr) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) //log.Printf("usage: %v", usage) - if openaiErr != nil { + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return service.OpenAIErrorToClaudeError(openaiErr) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 2f5f5d38..5b7dee80 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -213,7 +213,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) @@ -229,7 +229,7 @@ func 
GenRelayInfo(c *gin.Context) *RelayInfo { UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl), + BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), RequestURLPath: c.Request.URL.String(), ChannelType: channelType, ChannelId: channelId, @@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { IsModelMapped: false, ApiType: apiType, ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), Organization: c.GetString("channel_organization"), ChannelCreateTime: c.GetInt64("channel_create_time"), diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 0df219e3..ce823b3a 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -1,7 +1,6 @@ package common_handler import ( - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -9,13 +8,15 @@ import ( "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" - "one-api/service" + "one-api/types" + + "github.com/gin-gonic/gin" ) -func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -24,9 +25,9 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var jinaResp 
dto.RerankResponse if info.ChannelType == constant.ChannelTypeXinference { var xinRerankResponse xinference.XinRerankResponse - err = common.UnmarshalJson(responseBody, &xinRerankResponse) + err = common.Unmarshal(responseBody, &xinRerankResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) for i, result := range xinRerankResponse.Results { @@ -59,14 +60,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo }, } } else { - err = common.UnmarshalJson(responseBody, &jinaResp) + err = common.Unmarshal(responseBody, &jinaResp) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } c.Writer.Header().Set("Content-Type", "application/json") c.JSON(http.StatusOK, jinaResp) - return nil, &jinaResp.Usage + return &jinaResp.Usage, nil } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 849c70da..20b028ed 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/dto" @@ -12,6 +11,9 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" + + "github.com/gin-gonic/gin" ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { @@ -32,24 +34,24 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed return nil } -func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { 
relayInfo := relaycommon.GenRelayInfoEmbedding(c) var embeddingRequest *dto.EmbeddingRequest err := common.UnmarshalBodyReusable(c, &embeddingRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } promptToken := getEmbeddingPromptToken(*embeddingRequest) @@ -57,57 +59,57 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return 
types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 9185ce62..e448b491 100644 --- a/relay/gemini_handler.go 
+++ b/relay/gemini_handler.go @@ -14,6 +14,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -104,11 +105,11 @@ func trimModelThinking(modelName string) string { return modelName } -func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateGeminiRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } relayInfo := relaycommon.GenRelayInfoGemini(c) @@ -120,14 +121,14 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) } } // model mapped 模型映射 err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } if value, exists := c.Get("prompt_tokens"); exists { @@ -158,23 +159,23 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre consume quota - preConsumedQuota, userQuota, openaiErr := 
preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) @@ -195,7 +196,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { requestBody, err := json.Marshal(req) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } if common.DebugEnabled { @@ -205,7 +206,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -215,10 +216,10 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 
- service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } diff --git a/relay/helper/common.go b/relay/helper/common.go index 35d983f7..5d23b512 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -4,27 +4,29 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "net/http" "one-api/common" "one-api/dto" + "one-api/types" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func SetEventStreamHeaders(c *gin.Context) { - // 检查是否已经设置过头部 - if _, exists := c.Get("event_stream_headers_set"); exists { - return - } - - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") - - // 设置标志,表示头部已经设置过 - c.Set("event_stream_headers_set", true) + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { @@ -85,7 +87,7 @@ func ObjectData(c *gin.Context, object interface{}) error { if object == nil { return errors.New("object is nil") } - jsonData, err := json.Marshal(object) + jsonData, err := common.Marshal(object) if err != nil { return fmt.Errorf("error marshalling object: %w", err) } @@ -118,7 +120,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { return ws.WriteMessage(1, jsonData) } -func WssError(c *gin.Context, ws 
*websocket.Conn, openaiError dto.OpenAIError) { +func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { errorObj := &dto.RealtimeEvent{ Type: "error", EventId: GetLocalRealtimeID(c), diff --git a/relay/image_handler.go b/relay/image_handler.go index 5decb497..44f44277 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -16,6 +16,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -107,23 +108,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return imageRequest, nil } -func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { +func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoImage(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } var preConsumedQuota int var quota int @@ -132,13 +133,12 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // modelRatio 16 = modelPrice $0.04 // per 1 modelRatio = $0.04 / 16 // priceData.ModelPrice = 0.0025 * priceData.ModelRatio - var openaiErr *dto.OpenAIErrorWithStatusCode - preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, 
priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() @@ -169,16 +169,16 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeQueryDataError) } if userQuota-quota < 0 { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden) + return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota) } } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) @@ -186,14 +186,14 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } if relayInfo.RelayMode == 
relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } requestBody = bytes.NewBuffer(jsonData) } @@ -206,25 +206,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - openaiErr := service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } if usage.(*dto.Usage).TotalTokens == 0 { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f23f8152..e7f316b9 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons common.SysError("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { - 
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") + model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") } } if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { diff --git a/relay/relay-text.go b/relay/relay-text.go index 86b6c530..46120529 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -19,6 +19,7 @@ import ( "one-api/setting" "one-api/setting/model_setting" "one-api/setting/operation_setting" + "one-api/types" "strings" "time" @@ -84,7 +85,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) return textRequest, nil } -func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfo(c) @@ -92,8 +93,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { textRequest, err := getAndValidateTextRequest(c, relayInfo) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } if textRequest.WebSearchOptions != nil { @@ -104,13 +104,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) } } err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, 
types.ErrorCodeChannelModelMappedError) } // 获取 promptTokens,如果上下文中已经存在,则直接使用 @@ -122,23 +122,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens, err = getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeCountTokenFailed) } c.Set("prompt_tokens", promptTokens) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newApiErr != nil { + return newApiErr } defer func() { - if openaiErr != nil { + if newApiErr != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() @@ -166,7 +166,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) var requestBody io.Reader @@ -174,32 +174,29 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return 
service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } // apply param override if len(relayInfo.ParamOverride) > 0 { reqMap := make(map[string]interface{}) - err = json.Unmarshal(jsonData, &reqMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError) - } + _ = common.Unmarshal(jsonData, &reqMap) for key, value := range relayInfo.ParamOverride { reqMap[key] = value } - jsonData, err = json.Marshal(reqMap) + jsonData, err = common.Marshal(reqMap) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) } } @@ -213,7 +210,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -222,18 +219,18 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || 
strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newApiErr = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newApiErr, statusCodeMappingStr) + return newApiErr } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo) + if newApiErr != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newApiErr, statusCodeMappingStr) + return newApiErr } if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { @@ -281,16 +278,16 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom } // 预扣费并返回用户剩余配额 -func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { +func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) + return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError) } if userQuota <= 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) } if userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), 
common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) } relayInfo.UserQuota = userQuota if userQuota > 100*preConsumedQuota { @@ -314,11 +311,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError) } } return preConsumedQuota, userQuota, nil diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 5cf384a8..72ca6a0b 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -2,15 +2,16 @@ package relay import ( "bytes" - "encoding/json" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" + + "github.com/gin-gonic/gin" ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { @@ -22,27 +23,27 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) { var rerankRequest *dto.RerankRequest err := common.UnmarshalBodyReusable(c, 
&rerankRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) if rerankRequest.Query == "" { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest) + return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest) } if len(rerankRequest.Documents) == 0 { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) + return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest) } err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } promptToken := getRerankPromptToken(*rerankRequest) @@ -50,32 +51,32 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - 
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } requestBody := bytes.NewBuffer(jsonData) if common.DebugEnabled { @@ -83,7 +84,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith } resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -91,18 +92,18 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + if newAPIError != nil { // reset status 
code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil diff --git a/relay/responses_handler.go b/relay/responses_handler.go index e744e354..10fa448b 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -14,6 +14,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/model_setting" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -46,11 +47,11 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo return inputTokens } -func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest) } relayInfo := relaycommon.GenRelayInfoResponses(c, req) @@ -59,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) } } err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } if value, exists := c.Get("prompt_tokens"); exists { @@ -78,52 +79,52 @@ func 
ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre consume quota - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeReadRequestBodyFailed) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError) + return types.NewError(err, 
types.ErrorCodeConvertRequestFailed) } // apply param override if len(relayInfo.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) } for key, value := range relayInfo.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeConvertRequestFailed) } } @@ -136,7 +137,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -145,18 +146,18 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - openaiErr = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } } - usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } if 
strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { diff --git a/relay/websocket.go b/relay/websocket.go index 571f3a82..659e27d5 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -1,18 +1,18 @@ package relay import ( - "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "net/http" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/types" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) -func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { +func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) // get & validate textRequest 获取并验证文本请求 @@ -22,42 +22,31 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) //} - // map model name - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[relayInfo.OriginModelName] != "" { - relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName] - // set upstream model name - //isModelMapped = true - } + err := helper.ModelMappedHelper(c, relayInfo, nil) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelModelMappedError) } priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) if err != nil { - return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeModelPriceError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := 
preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr + preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return newAPIError } defer func() { - if openaiErr != nil { + if newAPIError != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) } }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) } adaptor.Init(relayInfo) //var requestBody io.Reader @@ -67,7 +56,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, relayInfo, nil) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { @@ -75,11 +64,11 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi defer relayInfo.TargetWs.Close() } - usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) - if openaiErr != nil { + usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo) + if newAPIError != nil { // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError } service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, userQuota, priceData, "") diff --git a/router/relay-router.go b/router/relay-router.go index b48c9dc7..5b293dbd 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -20,7 +20,7 @@ func 
SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } playgroundRouter := router.Group("/pg") - playgroundRouter.Use(middleware.UserAuth()) + playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) { playgroundRouter.POST("/chat/completions", controller.Playground) } diff --git a/service/channel.go b/service/channel.go index d50de78d..4d38e6ed 100644 --- a/service/channel.go +++ b/service/channel.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/model" "one-api/setting/operation_setting" + "one-api/types" "strings" ) @@ -16,17 +17,17 @@ func formatNotifyType(channelId int, status int) string { } // disable & notify -func DisableChannel(channelId int, channelName string, reason string) { - success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) +func DisableChannel(channelError types.ChannelError, reason string) { + success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason) if success { - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason) + NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content) } } -func EnableChannel(channelId int, channelName string) { - success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") +func EnableChannel(channelId int, usingKey string, channelName string) { + success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "") if success { subject := fmt.Sprintf("通道「%s」(#%d)已被启用", 
channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) @@ -34,14 +35,17 @@ func EnableChannel(channelId int, channelName string) { } } -func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool { +func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if !common.AutomaticDisableChannelEnabled { return false } if err == nil { return false } - if err.LocalError { + if types.IsChannelError(err) { + return true + } + if types.IsLocalError(err) { return false } if err.StatusCode == http.StatusUnauthorized { @@ -53,7 +57,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b return true } } - switch err.Error.Code { + oaiErr := err.ToOpenAIError() + switch oaiErr.Code { case "invalid_api_key": return true case "account_deactivated": @@ -63,7 +68,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b case "pre_consume_token_quota_failed": return true } - switch err.Error.Type { + switch oaiErr.Type { case "insufficient_quota": return true case "insufficient_user_quota": @@ -77,23 +82,16 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b return true } - lowerMessage := strings.ToLower(err.Error.Message) + lowerMessage := strings.ToLower(err.Error()) search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true) - if search { - return true - } - - return false + return search } -func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool { +func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } - if err != nil { - return false - } - if openaiWithStatusErr != nil { + if newAPIError != nil { return false } if status != common.ChannelStatusAutoDisabled { diff --git a/service/convert.go b/service/convert.go index c97f8475..593b59d9 100644 --- 
a/service/convert.go +++ b/service/convert.go @@ -163,7 +163,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re oaiToolMessage.SetStringContent(mediaMsg.GetStringContent()) } else { mediaContents := mediaMsg.ParseMediaContent() - encodeJson, _ := common.EncodeJson(mediaContents) + encodeJson, _ := common.Marshal(mediaContents) oaiToolMessage.SetStringContent(string(encodeJson)) } openAIMessages = append(openAIMessages, oaiToolMessage) diff --git a/service/error.go b/service/error.go index 21835f2a..e655f448 100644 --- a/service/error.go +++ b/service/error.go @@ -2,11 +2,13 @@ package service import ( "encoding/json" + "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/types" "strconv" "strings" ) @@ -25,32 +27,32 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) } } -// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode -func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { - text := err.Error() - lowerText := strings.ToLower(text) - if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") { - if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) - text = "请求上游地址失败" - } - } - openAIError := dto.OpenAIError{ - Message: text, - Type: "new_api_error", - Code: code, - } - return &dto.OpenAIErrorWithStatusCode{ - Error: openAIError, - StatusCode: statusCode, - } -} - -func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { - openaiErr := OpenAIErrorWrapper(err, code, statusCode) - openaiErr.LocalError = true - return openaiErr -} +//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode +//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode 
{ +// text := err.Error() +// lowerText := strings.ToLower(text) +// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") { +// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { +// common.SysLog(fmt.Sprintf("error: %s", text)) +// text = "请求上游地址失败" +// } +// } +// openAIError := dto.OpenAIError{ +// Message: text, +// Type: "new_api_error", +// Code: code, +// } +// return &dto.OpenAIErrorWithStatusCode{ +// Error: openAIError, +// StatusCode: statusCode, +// } +//} +// +//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { +// openaiErr := OpenAIErrorWrapper(err, code, statusCode) +// openaiErr.LocalError = true +// return openaiErr +//} func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { text := err.Error() @@ -77,43 +79,37 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude return claudeErr } -func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { - errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ +func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { + newApiErr = &types.NewAPIError{ StatusCode: resp.StatusCode, - Error: dto.OpenAIError{ - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, } + responseBody, err := io.ReadAll(resp.Body) if err != nil { return } common.CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse - err = json.Unmarshal(responseBody, &errResponse) + + err = common.Unmarshal(responseBody, &errResponse) if err != nil { if showBodyWhenFail { - errWithStatusCode.Error.Message = string(responseBody) + newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, 
string(responseBody)) } else { - errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } return } if errResponse.Error.Message != "" { - // OpenAI format error, so we override the default one - errWithStatusCode.Error = errResponse.Error + // General format error (OpenAI, Anthropic, Gemini, etc.) + newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode) } else { - errWithStatusCode.Error.Message = errResponse.ToMessage() - } - if errWithStatusCode.Error.Message == "" { - errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) } return } -func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) { +func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) { if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" { return } @@ -122,13 +118,13 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping if err != nil { return } - if openaiErr.StatusCode == http.StatusOK { + if newApiErr.StatusCode == http.StatusOK { return } - codeStr := strconv.Itoa(openaiErr.StatusCode) + codeStr := strconv.Itoa(newApiErr.StatusCode) if _, ok := statusCodeMapping[codeStr]; ok { intCode, _ := strconv.Atoi(statusCodeMapping[codeStr]) - openaiErr.StatusCode = intCode + newApiErr.StatusCode = intCode } } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index affae5fb..020a2ba9 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -1,6 +1,8 @@ package service import ( + "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -28,6 +30,11 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo 
*relaycommon.RelayInfo, m } adminInfo := make(map[string]interface{}) adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") + isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey) + if isMultiKey { + adminInfo["is_multi_key"] = true + adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex) + } other["admin_info"] = adminInfo return other } diff --git a/service/midjourney.go b/service/midjourney.go index 83404bd9..1fc19682 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - auth := c.Request.Header.Get("Authorization") + auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey) if auth != "" { auth = strings.TrimPrefix(auth, "Bearer ") req.Header.Set("mj-api-secret", auth) diff --git a/types/channel_error.go b/types/channel_error.go new file mode 100644 index 00000000..f2d72bf5 --- /dev/null +++ b/types/channel_error.go @@ -0,0 +1,21 @@ +package types + +type ChannelError struct { + ChannelId int `json:"channel_id"` + ChannelType int `json:"channel_type"` + ChannelName string `json:"channel_name"` + IsMultiKey bool `json:"is_multi_key"` + AutoBan bool `json:"auto_ban"` + UsingKey string `json:"using_key"` +} + +func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError { + return &ChannelError{ + ChannelId: channelId, + ChannelType: channelType, + ChannelName: channelName, + IsMultiKey: isMultiKey, + AutoBan: autoBan, + UsingKey: usingKey, + } +} diff --git a/types/error.go b/types/error.go new file mode 100644 index 00000000..7ef770ec --- /dev/null +++ b/types/error.go @@ -0,0 +1,195 @@ +package types + +import ( + "errors" + "fmt" + 
"net/http" + "strings" +) + +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +type ClaudeError struct { + Message string `json:"message,omitempty"` + Type string `json:"type,omitempty"` +} + +type ErrorType string + +const ( + ErrorTypeNewAPIError ErrorType = "new_api_error" + ErrorTypeOpenAIError ErrorType = "openai_error" + ErrorTypeClaudeError ErrorType = "claude_error" + ErrorTypeMidjourneyError ErrorType = "midjourney_error" + ErrorTypeGeminiError ErrorType = "gemini_error" + ErrorTypeRerankError ErrorType = "rerank_error" +) + +type ErrorCode string + +const ( + ErrorCodeInvalidRequest ErrorCode = "invalid_request" + ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected" + + // new api error + ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" + ErrorCodeModelPriceError ErrorCode = "model_price_error" + ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" + ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" + ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" + ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + + // channel error + ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" + ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" + ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" + ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" + ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" + ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" + + // client request error + ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" + ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed" + ErrorCodeAccessDenied ErrorCode = "access_denied" + + // response error + ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed" + 
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code" + ErrorCodeBadResponse ErrorCode = "bad_response" + ErrorCodeBadResponseBody ErrorCode = "bad_response_body" + + // sql error + ErrorCodeQueryDataError ErrorCode = "query_data_error" + ErrorCodeUpdateDataError ErrorCode = "update_data_error" + + // quota error + ErrorCodeInsufficientUserQuota ErrorCode = "insufficient_user_quota" + ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed" +) + +type NewAPIError struct { + Err error + RelayError any + ErrorType ErrorType + errorCode ErrorCode + StatusCode int +} + +func (e *NewAPIError) GetErrorCode() ErrorCode { + if e == nil { + return "" + } + return e.errorCode +} + +func (e *NewAPIError) Error() string { + return e.Err.Error() +} + +func (e *NewAPIError) SetMessage(message string) { + e.Err = errors.New(message) +} + +func (e *NewAPIError) ToOpenAIError() OpenAIError { + switch e.ErrorType { + case ErrorTypeOpenAIError: + return e.RelayError.(OpenAIError) + case ErrorTypeClaudeError: + claudeError := e.RelayError.(ClaudeError) + return OpenAIError{ + Message: e.Error(), + Type: claudeError.Type, + Param: "", + Code: e.errorCode, + } + default: + return OpenAIError{ + Message: e.Error(), + Type: string(e.ErrorType), + Param: "", + Code: e.errorCode, + } + } +} + +func (e *NewAPIError) ToClaudeError() ClaudeError { + switch e.ErrorType { + case ErrorTypeOpenAIError: + openAIError := e.RelayError.(OpenAIError) + return ClaudeError{ + Message: e.Error(), + Type: fmt.Sprintf("%v", openAIError.Code), + } + case ErrorTypeClaudeError: + return e.RelayError.(ClaudeError) + default: + return ClaudeError{ + Message: e.Error(), + Type: string(e.ErrorType), + } + } +} + +func NewError(err error, errorCode ErrorCode) *NewAPIError { + return &NewAPIError{ + Err: err, + RelayError: nil, + ErrorType: ErrorTypeNewAPIError, + StatusCode: http.StatusInternalServerError, + errorCode: errorCode, + } +} + +func NewErrorWithStatusCode(err 
error, errorCode ErrorCode, statusCode int) *NewAPIError { + return &NewAPIError{ + Err: err, + RelayError: nil, + ErrorType: ErrorTypeNewAPIError, + StatusCode: statusCode, + errorCode: errorCode, + } +} + +func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { + code, ok := openAIError.Code.(string) + if !ok { + code = fmt.Sprintf("%v", openAIError.Code) + } + return &NewAPIError{ + RelayError: openAIError, + ErrorType: ErrorTypeOpenAIError, + StatusCode: statusCode, + Err: errors.New(openAIError.Message), + errorCode: ErrorCode(code), + } +} + +func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { + return &NewAPIError{ + RelayError: claudeError, + ErrorType: ErrorTypeClaudeError, + StatusCode: statusCode, + Err: errors.New(claudeError.Message), + errorCode: ErrorCode(claudeError.Type), + } +} + +func IsChannelError(err *NewAPIError) bool { + if err == nil { + return false + } + return strings.HasPrefix(string(err.errorCode), "channel:") +} + +func IsLocalError(err *NewAPIError) bool { + if err == nil { + return false + } + + return err.ErrorType == ErrorTypeNewAPIError +} diff --git a/web/src/components/table/ChannelsTable.js b/web/src/components/table/ChannelsTable.js index df0838a7..19d759cd 100644 --- a/web/src/components/table/ChannelsTable.js +++ b/web/src/components/table/ChannelsTable.js @@ -42,18 +42,20 @@ import { IconTreeTriangleDown, IconSearch, IconMore, + IconList, IconDescend2 } from '@douyinfe/semi-icons'; import { loadChannelModels, isMobile, copy } from '../../helpers'; import EditTagModal from '../../pages/Channel/EditTagModal.js'; import { useTranslation } from 'react-i18next'; import { useTableCompactMode } from '../../hooks/useTableCompactMode'; +import { FaRandom } from 'react-icons/fa'; const ChannelsTable = () => { const { t } = useTranslation(); let type2label = undefined; - const renderType = (type) => { + const renderType = (type, channelInfo = undefined) => { if (!type2label) { type2label 
= new Map(); for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { @@ -61,11 +63,30 @@ const ChannelsTable = () => { } type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' }; } + + let icon = getChannelIcon(type); + + if (channelInfo?.is_multi_key) { + icon = ( + channelInfo?.multi_key_mode === 'random' ? ( +
+ + {icon} +
+ ) : ( +
+ + {icon} +
+ ) + ) + } + return ( {type2label[type]?.label} @@ -84,7 +105,19 @@ const ChannelsTable = () => { ); }; - const renderStatus = (status) => { + const renderStatus = (status, channelInfo = undefined) => { + if (channelInfo) { + if (channelInfo.is_multi_key) { + let keySize = channelInfo.multi_key_size; + let enabledKeySize = keySize; + if (channelInfo.multi_key_status_list) { + // multi_key_status_list is a map, key is key, value is status + // get multi_key_status_list length + enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length; + } + return renderMultiKeyStatus(status, keySize, enabledKeySize); + } + } switch (status) { case 1: return ( @@ -113,6 +146,36 @@ const ChannelsTable = () => { } }; + const renderMultiKeyStatus = (status, keySize, enabledKeySize) => { + switch (status) { + case 1: + return ( + + {t('已启用')} {enabledKeySize}/{keySize} + + ); + case 2: + return ( + + {t('已禁用')} {enabledKeySize}/{keySize} + + ); + case 3: + return ( + + {t('自动禁用')} {enabledKeySize}/{keySize} + + ); + default: + return ( + + {t('未知状态')} {enabledKeySize}/{keySize} + + ); + } + } + + const renderResponseTime = (responseTime) => { let time = responseTime / 1000; time = time.toFixed(2) + t(' 秒'); @@ -279,6 +342,11 @@ const ChannelsTable = () => { dataIndex: 'type', render: (text, record, index) => { if (record.children === undefined) { + if (record.channel_info) { + if (record.channel_info.is_multi_key) { + return <>{renderType(text, record.channel_info)}; + } + } return <>{renderType(text)}; } else { return <>{renderTagType()}; @@ -302,12 +370,12 @@ const ChannelsTable = () => { - {renderStatus(text)} + {renderStatus(text, record.channel_info)} ); } else { - return renderStatus(text); + return renderStatus(text, record.channel_info); } }, }, @@ -524,24 +592,70 @@ const ChannelsTable = () => { /> - {record.status === 1 ? ( - + { + record.status === 1 ? 
( + + ) : ( + + ) + } + manageChannel(record.id, 'enable_all', record), + } + ]} + > + + record.status === 1 ? ( + + ) : ( + + ) )}