From cd8c23c0abe70d8f962b9f5021f252c3b1afa9cb Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 10 Jul 2025 17:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(channel):=20enhance=20channel?= =?UTF-8?q?=20status=20management?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constant/context_key.go | 1 + controller/channel-billing.go | 8 +- controller/channel-test.go | 157 +++++++++++++++++----- controller/channel.go | 2 + controller/relay.go | 55 ++++---- middleware/distributor.go | 14 +- model/cache.go | 13 ++ model/channel.go | 112 ++++++++++----- relay/channel/tencent/adaptor.go | 3 +- relay/common/relay_info.go | 2 +- relay/relay-mj.go | 2 +- service/channel.go | 19 ++- service/midjourney.go | 2 +- types/channel_error.go | 21 +++ types/error.go | 1 + web/src/components/table/ChannelsTable.js | 70 +++++++++- 16 files changed, 363 insertions(+), 119 deletions(-) create mode 100644 types/channel_error.go diff --git a/constant/context_key.go b/constant/context_key.go index d58f1205..36fe39ed 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -29,6 +29,7 @@ const ( ContextKeyChannelModelMapping ContextKey = "model_mapping" ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping" ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key" + ContextKeyChannelKey ContextKey = "channel_key" /* user related keys */ ContextKeyUserId ContextKey = "id" diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 3c92c78b..a93f6c64 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -452,14 +452,14 @@ func updateAllChannelsBalance() error { //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { // continue //} - balance, err := updateChannelBalance(channel) + _, err := updateChannelBalance(channel) if err != nil { continue } else { // err is nil & balance <= 0 means quota is used up - if balance <= 0 { - service.DisableChannel(channel.Id, channel.Name, "余额不足") - } + //if balance <= 0 { + // service.DisableChannel(channel.Id, channel.Name, "余额不足") + //} } time.Sleep(common.RequestInterval) } diff --git a/controller/channel-test.go b/controller/channel-test.go index 89c1a133..82bb1d7f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -30,22 +30,43 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) { +type testResult struct { + context *gin.Context + localErr error + newAPIError *types.NewAPIError +} + +func testChannel(channel *model.Channel, testModel string) testResult { tik := time.Now() if channel.Type == constant.ChannelTypeMidjourney { - return errors.New("midjourney channel test is not supported"), nil + return testResult{ + localErr: errors.New("midjourney channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeMidjourneyPlus { - return errors.New("midjourney plus channel test is not supported"), nil + return testResult{ + localErr: errors.New("midjourney plus channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeSunoAPI { - return errors.New("suno channel test is not supported"), nil + return testResult{ + localErr: errors.New("suno channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeKling { - return errors.New("kling channel test is not supported"), nil + return testResult{ + localErr: errors.New("kling channel test is not supported"), + newAPIError: nil, + } } if channel.Type == constant.ChannelTypeJimeng { - return errors.New("jimeng channel test is not supported"), nil + return testResult{ + localErr: errors.New("jimeng channel test is not supported"), + newAPIError: nil, + } } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -82,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr cache, err := model.GetUserCache(1) if err != nil { - return err, nil + return testResult{ + localErr: err, + newAPIError: nil, + } } cache.WriteContext(c) @@ -93,20 +117,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr group, _ := model.GetUserGroup(1, false) c.Set("group", group) - middleware.SetupContextForSelectedChannel(c, channel, testModel) + newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel) + if newAPIError != nil { + return testResult{ + context: c, + localErr: newAPIError, + newAPIError: newAPIError, + } + } info := relaycommon.GenRelayInfo(c) err = helper.ModelMappedHelper(c, info, nil) if err != nil { - return err, types.NewError(err, types.ErrorCodeChannelModelMappedError) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), + } } testModel = info.UpstreamModelName apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType) + return testResult{ + context: c, + localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), + newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType), + } } request := buildTestRequest(testModel) @@ -117,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) if err != nil { - return err, types.NewError(err, types.ErrorCodeModelPriceError) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeModelPriceError), + } } adaptor.Init(info) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { - return err, types.NewError(err, types.ErrorCodeConvertRequestFailed) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), + } } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed), + } } requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - return err, types.NewError(err, types.ErrorCodeDoRequestFailed) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed), + } } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { err := service.RelayErrorHandler(httpResp, true) - return err, types.NewError(err, types.ErrorCodeBadResponse) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeBadResponse), + } } } usageA, respErr := adaptor.DoResponse(c, httpResp, info) if respErr != nil { - return respErr, respErr + return testResult{ + context: c, + localErr: respErr, + newAPIError: respErr, + } } if usageA == nil { - return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody) + return testResult{ + context: c, + localErr: errors.New("usage is nil"), + newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody), + } } usage := usageA.(*dto.Usage) result := w.Result() respBody, err := io.ReadAll(result.Body) if err != nil { - return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed), + } } info.PromptTokens = usage.PromptTokens @@ -188,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr Other: other, }) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) - return nil, nil + return testResult{ + context: c, + localErr: nil, + newAPIError: nil, + } } func buildTestRequest(model string) *dto.GeneralOpenAIRequest { @@ -247,15 +322,23 @@ func TestChannel(c *gin.Context) { } testModel := c.Query("model") tik := time.Now() - _, newAPIError := testChannel(channel, testModel) + result := testChannel(channel, testModel) + if result.localErr != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": result.localErr.Error(), + "time": 0.0, + }) + return + } tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) consumedTime := float64(milliseconds) / 1000.0 - if newAPIError != nil { + if result.newAPIError != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": newAPIError.Error(), + "message": result.newAPIError.Error(), "time": consumedTime, }) return @@ -280,9 +363,9 @@ func testAllChannels(notify bool) error { } testAllChannelsRunning = true testAllChannelsLock.Unlock() - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - return err + channels, getChannelErr := model.GetAllChannels(0, 0, true, false) + if getChannelErr != nil { + return getChannelErr } var disableThreshold = int64(common.ChannelDisableThreshold * 1000) if disableThreshold == 0 { @@ -299,30 +382,34 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, newAPIError := testChannel(channel, "") + result := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() shouldBanChannel := false - + newAPIError := result.newAPIError // request error disables the channel - if err != nil { - shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError) + if newAPIError != nil { + shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError) } - if milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - shouldBanChannel = true + // 当错误检查通过,才检查响应时间 + if common.AutomaticDisableChannelEnabled && !shouldBanChannel { + if milliseconds > disableThreshold { + err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded) + shouldBanChannel = true + } } // disable channel if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { - service.DisableChannel(channel.Id, channel.Name, err.Error()) + go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) } // enable channel - if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) { - service.EnableChannel(channel.Id, channel.Name) + if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) { + service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/channel.go b/controller/channel.go index c9f20fa5..6e387064 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -497,6 +497,7 @@ func AddChannel(c *gin.Context) { }) return } + addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array) addChannelRequest.Channel.Key = strings.Join(array, "\n") } else { cleanKeys := make([]string, 0) @@ -507,6 +508,7 @@ func AddChannel(c *gin.Context) { key = strings.TrimSpace(key) cleanKeys = append(cleanKeys, key) } + addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys) addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") } keys = []string{addChannelRequest.Channel.Key} diff --git a/controller/relay.go b/controller/relay.go index 018351d2..b224b42c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -80,7 +80,7 @@ func Relay(c *gin.Context) { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) + newAPIError = err break } @@ -90,7 +90,7 @@ func Relay(c *gin.Context) { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) + go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break @@ -103,10 +103,10 @@ func Relay(c *gin.Context) { } if newAPIError != nil { - if newAPIError.StatusCode == http.StatusTooManyRequests { - common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) - newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - } + //if newAPIError.StatusCode == http.StatusTooManyRequests { + // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) + // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") + //} newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) c.JSON(newAPIError.StatusCode, gin.H{ "error": newAPIError.ToOpenAIError(), @@ -143,7 +143,7 @@ func WssRelay(c *gin.Context) { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) + newAPIError = err break } @@ -153,7 +153,7 @@ func WssRelay(c *gin.Context) { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) + go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break @@ -166,9 +166,9 @@ func WssRelay(c *gin.Context) { } if newAPIError != nil { - if newAPIError.StatusCode == http.StatusTooManyRequests { - newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - } + //if newAPIError.StatusCode == http.StatusTooManyRequests { + // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") + //} newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) helper.WssError(c, ws, newAPIError.ToOpenAIError()) } @@ -185,7 +185,7 @@ func RelayClaude(c *gin.Context) { channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, err.Error()) - newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) + newAPIError = err break } @@ -195,7 +195,7 @@ func RelayClaude(c *gin.Context) { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) + go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break @@ -243,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) { c.Set("use_channel", useChannel) } -func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { +func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) { if retryCount == 0 { autoBan := c.GetBool("auto_ban") autoBanInt := 1 @@ -260,11 +260,14 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { if group == "auto" { - return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())) + return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed) } - return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())) + return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + } + newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) + if newAPIError != nil { + return nil, newAPIError } - middleware.SetupContextForSelectedChannel(c, channel, originalModel) return channel, nil } @@ -314,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b return true } -func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) { +func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error())) - if service.ShouldDisableChannel(channelType, err) && autoBan { - service.DisableChannel(channelId, channelName, err.Error()) + common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + service.DisableChannel(channelError, err.Error()) } } @@ -392,10 +395,10 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) - taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + channel, newAPIError := getChannel(c, group, originalModel, i) + if newAPIError != nil { + common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } channelId = channel.Id @@ -405,7 +408,7 @@ func RelayTask(c *gin.Context) { common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) + requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) taskErr = taskRelayHandler(c, relayMode) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 18959e61..7c30daf3 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,6 +12,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strconv" "strings" "time" @@ -249,10 +250,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { return &modelRequest, shouldSelectChannel, nil } -func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { +func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) } common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) @@ -270,7 +271,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) } - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + key, newAPIError := channel.GetNextEnabledKey() + if newAPIError != nil { + return newAPIError + } + // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) + common.SetContextKey(c, constant.ContextKeyChannelKey, key) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) // TODO: api_version统一 @@ -292,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode case constant.ChannelTypeCoze: c.Set("bot_id", channel.Other) } + return nil } // extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名 diff --git a/model/cache.go b/model/cache.go index 3e5eb4c4..340d14ff 100644 --- a/model/cache.go +++ b/model/cache.go @@ -203,3 +203,16 @@ func CacheUpdateChannelStatus(id int, status int) { channel.Status = status } } + +func CacheUpdateChannel(channel *Channel) { + if !common.MemoryCacheEnabled { + return + } + channelSyncLock.Lock() + defer channelSyncLock.Unlock() + + if channel == nil { + return + } + channelsIDM[channel.Id] = channel +} diff --git a/model/channel.go b/model/channel.go index 9d2ad853..fea4ce61 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,11 +3,12 @@ package model import ( "database/sql/driver" "encoding/json" - "fmt" + "errors" "math/rand" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/types" "strings" "sync" @@ -48,6 +49,7 @@ type Channel struct { type ChannelInfo struct { IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` @@ -73,7 +75,7 @@ func (channel *Channel) getKeys() []string { return keys } -func (channel *Channel) GetNextEnabledKey() (string, error) { +func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) { // If not in multi-key mode, return the original key string directly. if !channel.ChannelInfo.IsMultiKey { return channel.Key, nil @@ -83,7 +85,7 @@ func (channel *Channel) GetNextEnabledKey() (string, error) { keys := channel.getKeys() if len(keys) == 0 { // No keys available, return error, should disable the channel - return "", fmt.Errorf("no valid keys in channel") + return "", types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) } statusList := channel.ChannelInfo.MultiKeyStatusList @@ -404,48 +406,94 @@ func (channel *Channel) Delete() error { var channelStatusLock sync.Mutex -func UpdateChannelStatusById(id int, status int, reason string) bool { +func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { + keys := channel.getKeys() + if len(keys) == 0 { + channel.Status = status + } else { + var keyIndex int + for i, key := range keys { + if key == usingKey { + keyIndex = i + break + } + } + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if status == common.ChannelStatusEnabled { + delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) + } else { + channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status + } + if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { + channel.Status = common.ChannelStatusAutoDisabled + info := channel.GetOtherInfo() + info["status_reason"] = "All keys are disabled" + info["status_time"] = common.GetTimestamp() + channel.SetOtherInfo(info) + } + } +} + +func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() - channelCache, _ := CacheGetChannel(id) - // 如果缓存渠道存在,且状态已是目标状态,直接返回 - if channelCache != nil && channelCache.Status == status { + channelCache, _ := CacheGetChannel(channelId) + if channelCache == nil { return false } - // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 - if channelCache == nil && status != common.ChannelStatusEnabled { - return false + if channelCache.ChannelInfo.IsMultiKey { + // 如果是多Key模式,更新缓存中的状态 + handlerMultiKeyUpdate(channelCache, usingKey, status) + CacheUpdateChannel(channelCache) + //return true + } else { + // 如果缓存渠道存在,且状态已是目标状态,直接返回 + if channelCache.Status == status { + return false + } + // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 + if status != common.ChannelStatusEnabled { + return false + } + CacheUpdateChannelStatus(channelId, status) } - CacheUpdateChannelStatus(id, status) } - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + + shouldUpdateAbilities := false + defer func() { + if shouldUpdateAbilities { + err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) + if err != nil { + common.SysError("failed to update ability status: " + err.Error()) + } + } + }() + channel, err := GetChannelById(channelId, true) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) return false - } - channel, err := GetChannelById(id, true) - if err != nil { - // find channel by id error, directly update status - result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status) - if result.Error != nil { - common.SysError("failed to update channel status: " + result.Error.Error()) - return false - } - if result.RowsAffected == 0 { - return false - } } else { if channel.Status == status { return false } - // find channel by id success, update status and other info - info := channel.GetOtherInfo() - info["status_reason"] = reason - info["status_time"] = common.GetTimestamp() - channel.SetOtherInfo(info) - channel.Status = status + + if channel.ChannelInfo.IsMultiKey { + beforeStatus := channel.Status + handlerMultiKeyUpdate(channel, usingKey, status) + if beforeStatus != channel.Status { + shouldUpdateAbilities = true + } + } else { + info := channel.GetOtherInfo() + info["status_reason"] = reason + info["status_time"] = common.GetTimestamp() + channel.SetOtherInfo(info) + channel.Status = status + shouldUpdateAbilities = true + } err = channel.Save() if err != nil { common.SysError("failed to update channel status: " + err.Error()) @@ -628,6 +676,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { err := json.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) + channel.Setting = nil // 清空设置以避免后续错误 + _ = channel.Save() // 保存修改 } } return setting diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 1a5a55e3..520276a7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" @@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - apiKey := c.Request.Header.Get("Authorization") + apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey) apiKey = strings.TrimPrefix(apiKey, "Bearer ") appId, secretId, secretKey, err := parseTencentConfig(apiKey) a.AppID = appId diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index beada0ee..5b7dee80 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { IsModelMapped: false, ApiType: apiType, ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), Organization: c.GetString("channel_organization"), ChannelCreateTime: c.GetInt64("channel_create_time"), diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f23f8152..e7f316b9 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons common.SysError("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { - model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") + model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") } } if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { diff --git a/service/channel.go b/service/channel.go index 5c268c87..4d38e6ed 100644 --- a/service/channel.go +++ b/service/channel.go @@ -17,17 +17,17 @@ func formatNotifyType(channelId int, status int) string { } // disable & notify -func DisableChannel(channelId int, channelName string, reason string) { - success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) +func DisableChannel(channelError types.ChannelError, reason string) { + success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason) if success { - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason) + NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content) } } -func EnableChannel(channelId int, channelName string) { - success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") +func EnableChannel(channelId int, usingKey string, channelName string) { + success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "") if success { subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) @@ -87,13 +87,10 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { return search } -func ShouldEnableChannel(err error, newAPIError *types.NewAPIError, status int) bool { +func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } - if err != nil { - return false - } if newAPIError != nil { return false } diff --git a/service/midjourney.go b/service/midjourney.go index 83404bd9..1fc19682 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - auth := c.Request.Header.Get("Authorization") + auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey) if auth != "" { auth = strings.TrimPrefix(auth, "Bearer ") req.Header.Set("mj-api-secret", auth) diff --git a/types/channel_error.go b/types/channel_error.go new file mode 100644 index 00000000..f2d72bf5 --- /dev/null +++ b/types/channel_error.go @@ -0,0 +1,21 @@ +package types + +type ChannelError struct { + ChannelId int `json:"channel_id"` + ChannelType int `json:"channel_type"` + ChannelName string `json:"channel_name"` + IsMultiKey bool `json:"is_multi_key"` + AutoBan bool `json:"auto_ban"` + UsingKey string `json:"using_key"` +} + +func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError { + return &ChannelError{ + ChannelId: channelId, + ChannelType: channelType, + ChannelName: channelName, + IsMultiKey: isMultiKey, + AutoBan: autoBan, + UsingKey: usingKey, + } +} diff --git a/types/error.go b/types/error.go index 63e79c25..7ef770ec 100644 --- a/types/error.go +++ b/types/error.go @@ -50,6 +50,7 @@ const ( ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" + ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" // client request error ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" diff --git a/web/src/components/table/ChannelsTable.js b/web/src/components/table/ChannelsTable.js index 810993c4..67be1a2b 100644 --- a/web/src/components/table/ChannelsTable.js +++ b/web/src/components/table/ChannelsTable.js @@ -42,6 +42,7 @@ import { IconTreeTriangleDown, IconSearch, IconMore, + IconList } from '@douyinfe/semi-icons'; import { loadChannelModels, isMobile, copy } from '../../helpers'; import EditTagModal from '../../pages/Channel/EditTagModal.js'; @@ -53,7 +54,7 @@ const ChannelsTable = () => { let type2label = undefined; - const renderType = (type) => { + const renderType = (type, multiKey = false) => { if (!type2label) { type2label = new Map(); for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { @@ -61,12 +62,24 @@ const ChannelsTable = () => { } type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' }; } + + let icon = getChannelIcon(type); + + if (multiKey) { + icon = ( +
+ + {icon} +
+ ); + } + return ( {type2label[type]?.label} @@ -86,7 +99,19 @@ const ChannelsTable = () => { ); }; - const renderStatus = (status) => { + const renderStatus = (status, channelInfo = undefined) => { + if (channelInfo) { + if (channelInfo.is_multi_key) { + let keySize = channelInfo.multi_key_size; + let enabledKeySize = keySize; + if (channelInfo.multi_key_status_list) { + // multi_key_status_list is a map, key is key, value is status + // get multi_key_status_list length + enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length; + } + return renderMultiKeyStatus(status, keySize, enabledKeySize); + } + } switch (status) { case 1: return ( @@ -115,6 +140,36 @@ const ChannelsTable = () => { } }; + const renderMultiKeyStatus = (status, keySize, enabledKeySize) => { + switch (status) { + case 1: + return ( + + {t('已启用')} {enabledKeySize}/{keySize} + + ); + case 2: + return ( + + {t('已禁用')} {enabledKeySize}/{keySize} + + ); + case 3: + return ( + + {t('自动禁用')} {enabledKeySize}/{keySize} + + ); + default: + return ( + + {t('未知状态')} {enabledKeySize}/{keySize} + + ); + } + } + + const renderResponseTime = (responseTime) => { let time = responseTime / 1000; time = time.toFixed(2) + t(' 秒'); @@ -281,6 +336,11 @@ const ChannelsTable = () => { dataIndex: 'type', render: (text, record, index) => { if (record.children === undefined) { + if (record.channel_info) { + if (record.channel_info.is_multi_key) { + return <>{renderType(text, record.channel_info)}; + } + } return <>{renderType(text)}; } else { return <>{renderTagType()}; @@ -304,12 +364,12 @@ const ChannelsTable = () => { - {renderStatus(text)} + {renderStatus(text, record.channel_info)} ); } else { - return renderStatus(text); + return renderStatus(text, record.channel_info); } }, },