feat(channel): enhance channel status management

This commit is contained in:
CaIon
2025-07-10 17:49:53 +08:00
parent a9e03e6172
commit cd8c23c0ab
16 changed files with 363 additions and 119 deletions

View File

@@ -29,6 +29,7 @@ const (
ContextKeyChannelModelMapping ContextKey = "model_mapping" ContextKeyChannelModelMapping ContextKey = "model_mapping"
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping" ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key" ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
ContextKeyChannelKey ContextKey = "channel_key"
/* user related keys */ /* user related keys */
ContextKeyUserId ContextKey = "id" ContextKeyUserId ContextKey = "id"

View File

@@ -452,14 +452,14 @@ func updateAllChannelsBalance() error {
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
// continue // continue
//} //}
balance, err := updateChannelBalance(channel) _, err := updateChannelBalance(channel)
if err != nil { if err != nil {
continue continue
} else { } else {
// err is nil & balance <= 0 means quota is used up // err is nil & balance <= 0 means quota is used up
if balance <= 0 { //if balance <= 0 {
service.DisableChannel(channel.Id, channel.Name, "余额不足") // service.DisableChannel(channel.Id, channel.Name, "余额不足")
} //}
} }
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
} }

View File

@@ -30,22 +30,43 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) { type testResult struct {
context *gin.Context
localErr error
newAPIError *types.NewAPIError
}
func testChannel(channel *model.Channel, testModel string) testResult {
tik := time.Now() tik := time.Now()
if channel.Type == constant.ChannelTypeMidjourney { if channel.Type == constant.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil return testResult{
localErr: errors.New("midjourney channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == constant.ChannelTypeMidjourneyPlus { if channel.Type == constant.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported"), nil return testResult{
localErr: errors.New("midjourney plus channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == constant.ChannelTypeSunoAPI { if channel.Type == constant.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil return testResult{
localErr: errors.New("suno channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == constant.ChannelTypeKling { if channel.Type == constant.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil return testResult{
localErr: errors.New("kling channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == constant.ChannelTypeJimeng { if channel.Type == constant.ChannelTypeJimeng {
return errors.New("jimeng channel test is not supported"), nil return testResult{
localErr: errors.New("jimeng channel test is not supported"),
newAPIError: nil,
}
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -82,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
cache, err := model.GetUserCache(1) cache, err := model.GetUserCache(1)
if err != nil { if err != nil {
return err, nil return testResult{
localErr: err,
newAPIError: nil,
}
} }
cache.WriteContext(c) cache.WriteContext(c)
@@ -93,20 +117,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
group, _ := model.GetUserGroup(1, false) group, _ := model.GetUserGroup(1, false)
c.Set("group", group) c.Set("group", group)
middleware.SetupContextForSelectedChannel(c, channel, testModel) newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
if newAPIError != nil {
return testResult{
context: c,
localErr: newAPIError,
newAPIError: newAPIError,
}
}
info := relaycommon.GenRelayInfo(c) info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info, nil) err = helper.ModelMappedHelper(c, info, nil)
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeChannelModelMappedError) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
}
} }
testModel = info.UpstreamModelName testModel = info.UpstreamModelName
apiType, _ := common.ChannelType2APIType(channel.Type) apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType) return testResult{
context: c,
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
}
} }
request := buildTestRequest(testModel) request := buildTestRequest(testModel)
@@ -117,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeModelPriceError) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
}
} }
adaptor.Init(info) adaptor.Init(info)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeConvertRequestFailed) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
}
} }
jsonData, err := json.Marshal(convertedRequest) jsonData, err := json.Marshal(convertedRequest)
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
} }
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody) c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeDoRequestFailed) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
}
} }
var httpResp *http.Response var httpResp *http.Response
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp, true) err := service.RelayErrorHandler(httpResp, true)
return err, types.NewError(err, types.ErrorCodeBadResponse) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
}
} }
} }
usageA, respErr := adaptor.DoResponse(c, httpResp, info) usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil { if respErr != nil {
return respErr, respErr return testResult{
context: c,
localErr: respErr,
newAPIError: respErr,
}
} }
if usageA == nil { if usageA == nil {
return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody) return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
}
} }
usage := usageA.(*dto.Usage) usage := usageA.(*dto.Usage)
result := w.Result() result := w.Result()
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
}
} }
info.PromptTokens = usage.PromptTokens info.PromptTokens = usage.PromptTokens
@@ -188,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
Other: other, Other: other,
}) })
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil return testResult{
context: c,
localErr: nil,
newAPIError: nil,
}
} }
func buildTestRequest(model string) *dto.GeneralOpenAIRequest { func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
@@ -247,15 +322,23 @@ func TestChannel(c *gin.Context) {
} }
testModel := c.Query("model") testModel := c.Query("model")
tik := time.Now() tik := time.Now()
_, newAPIError := testChannel(channel, testModel) result := testChannel(channel, testModel)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.localErr.Error(),
"time": 0.0,
})
return
}
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if newAPIError != nil { if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": newAPIError.Error(), "message": result.newAPIError.Error(),
"time": consumedTime, "time": consumedTime,
}) })
return return
@@ -280,9 +363,9 @@ func testAllChannels(notify bool) error {
} }
testAllChannelsRunning = true testAllChannelsRunning = true
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true, false) channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
if err != nil { if getChannelErr != nil {
return err return getChannelErr
} }
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
@@ -299,30 +382,34 @@ func testAllChannels(notify bool) error {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, newAPIError := testChannel(channel, "") result := testChannel(channel, "")
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false shouldBanChannel := false
newAPIError := result.newAPIError
// request error disables the channel // request error disables the channel
if err != nil { if newAPIError != nil {
shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError) shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
} }
if milliseconds > disableThreshold { // 当错误检查通过,才检查响应时间
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
shouldBanChannel = true if milliseconds > disableThreshold {
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
shouldBanChannel = true
}
} }
// disable channel // disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error()) go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
} }
// enable channel // enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) { if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
service.EnableChannel(channel.Id, channel.Name) service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)

View File

@@ -497,6 +497,7 @@ func AddChannel(c *gin.Context) {
}) })
return return
} }
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
addChannelRequest.Channel.Key = strings.Join(array, "\n") addChannelRequest.Channel.Key = strings.Join(array, "\n")
} else { } else {
cleanKeys := make([]string, 0) cleanKeys := make([]string, 0)
@@ -507,6 +508,7 @@ func AddChannel(c *gin.Context) {
key = strings.TrimSpace(key) key = strings.TrimSpace(key)
cleanKeys = append(cleanKeys, key) cleanKeys = append(cleanKeys, key)
} }
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
} }
keys = []string{addChannelRequest.Channel.Key} keys = []string{addChannelRequest.Channel.Key}

View File

@@ -80,7 +80,7 @@ func Relay(c *gin.Context) {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) newAPIError = err
break break
} }
@@ -90,7 +90,7 @@ func Relay(c *gin.Context) {
return // 成功处理请求,直接返回 return // 成功处理请求,直接返回
} }
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) { if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break break
@@ -103,10 +103,10 @@ func Relay(c *gin.Context) {
} }
if newAPIError != nil { if newAPIError != nil {
if newAPIError.StatusCode == http.StatusTooManyRequests { //if newAPIError.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
} //}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{ c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(), "error": newAPIError.ToOpenAIError(),
@@ -143,7 +143,7 @@ func WssRelay(c *gin.Context) {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) newAPIError = err
break break
} }
@@ -153,7 +153,7 @@ func WssRelay(c *gin.Context) {
return // 成功处理请求,直接返回 return // 成功处理请求,直接返回
} }
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) { if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break break
@@ -166,9 +166,9 @@ func WssRelay(c *gin.Context) {
} }
if newAPIError != nil { if newAPIError != nil {
if newAPIError.StatusCode == http.StatusTooManyRequests { //if newAPIError.StatusCode == http.StatusTooManyRequests {
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
} //}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
helper.WssError(c, ws, newAPIError.ToOpenAIError()) helper.WssError(c, ws, newAPIError.ToOpenAIError())
} }
@@ -185,7 +185,7 @@ func RelayClaude(c *gin.Context) {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed) newAPIError = err
break break
} }
@@ -195,7 +195,7 @@ func RelayClaude(c *gin.Context) {
return // 成功处理请求,直接返回 return // 成功处理请求,直接返回
} }
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError) go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) { if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break break
@@ -243,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
} }
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 { if retryCount == 0 {
autoBan := c.GetBool("auto_ban") autoBan := c.GetBool("auto_ban")
autoBanInt := 1 autoBanInt := 1
@@ -260,11 +260,14 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil { if err != nil {
if group == "auto" { if group == "auto" {
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())) return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
} }
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())) return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
return nil, newAPIError
} }
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil return channel, nil
} }
@@ -314,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
return true return true
} }
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) { func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况 // 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error())) common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
if service.ShouldDisableChannel(channelType, err) && autoBan { if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelId, channelName, err.Error()) service.DisableChannel(channelError, err.Error())
} }
} }
@@ -392,10 +395,10 @@ func RelayTask(c *gin.Context) {
retryTimes = 0 retryTimes = 0
} }
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i) channel, newAPIError := getChannel(c, group, originalModel, i)
if err != nil { if newAPIError != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break break
} }
channelId = channel.Id channelId = channel.Id
@@ -405,7 +408,7 @@ func RelayTask(c *gin.Context) {
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
//middleware.SetupContextForSelectedChannel(c, channel, originalModel) //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode) taskErr = taskRelayHandler(c, relayMode)
} }

View File

@@ -12,6 +12,7 @@ import (
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"one-api/types"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -249,10 +250,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
return &modelRequest, shouldSelectChannel, nil return &modelRequest, shouldSelectChannel, nil
} }
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry c.Set("original_model", modelName) // for retry
if channel == nil { if channel == nil {
return return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
} }
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
@@ -270,7 +271,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
} }
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
key, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
return newAPIError
}
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
// TODO: api_version统一 // TODO: api_version统一
@@ -292,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
case constant.ChannelTypeCoze: case constant.ChannelTypeCoze:
c.Set("bot_id", channel.Other) c.Set("bot_id", channel.Other)
} }
return nil
} }
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名 // extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名

View File

@@ -203,3 +203,16 @@ func CacheUpdateChannelStatus(id int, status int) {
channel.Status = status channel.Status = status
} }
} }
func CacheUpdateChannel(channel *Channel) {
if !common.MemoryCacheEnabled {
return
}
channelSyncLock.Lock()
defer channelSyncLock.Unlock()
if channel == nil {
return
}
channelsIDM[channel.Id] = channel
}

View File

@@ -3,11 +3,12 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt" "errors"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/types"
"strings" "strings"
"sync" "sync"
@@ -48,6 +49,7 @@ type Channel struct {
type ChannelInfo struct { type ChannelInfo struct {
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
@@ -73,7 +75,7 @@ func (channel *Channel) getKeys() []string {
return keys return keys
} }
func (channel *Channel) GetNextEnabledKey() (string, error) { func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
// If not in multi-key mode, return the original key string directly. // If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey { if !channel.ChannelInfo.IsMultiKey {
return channel.Key, nil return channel.Key, nil
@@ -83,7 +85,7 @@ func (channel *Channel) GetNextEnabledKey() (string, error) {
keys := channel.getKeys() keys := channel.getKeys()
if len(keys) == 0 { if len(keys) == 0 {
// No keys available, return error, should disable the channel // No keys available, return error, should disable the channel
return "", fmt.Errorf("no valid keys in channel") return "", types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
} }
statusList := channel.ChannelInfo.MultiKeyStatusList statusList := channel.ChannelInfo.MultiKeyStatusList
@@ -404,48 +406,94 @@ func (channel *Channel) Delete() error {
var channelStatusLock sync.Mutex var channelStatusLock sync.Mutex
func UpdateChannelStatusById(id int, status int, reason string) bool { func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
if len(keys) == 0 {
channel.Status = status
} else {
var keyIndex int
for i, key := range keys {
if key == usingKey {
keyIndex = i
break
}
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if status == common.ChannelStatusEnabled {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
} else {
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
}
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
info := channel.GetOtherInfo()
info["status_reason"] = "All keys are disabled"
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
}
}
}
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
if common.MemoryCacheEnabled { if common.MemoryCacheEnabled {
channelStatusLock.Lock() channelStatusLock.Lock()
defer channelStatusLock.Unlock() defer channelStatusLock.Unlock()
channelCache, _ := CacheGetChannel(id) channelCache, _ := CacheGetChannel(channelId)
// 如果缓存渠道存在,且状态已是目标状态,直接返回 if channelCache == nil {
if channelCache != nil && channelCache.Status == status {
return false return false
} }
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 if channelCache.ChannelInfo.IsMultiKey {
if channelCache == nil && status != common.ChannelStatusEnabled { // 如果是多Key模式更新缓存中的状态
return false handlerMultiKeyUpdate(channelCache, usingKey, status)
CacheUpdateChannel(channelCache)
//return true
} else {
// 如果缓存渠道存在,且状态已是目标状态,直接返回
if channelCache.Status == status {
return false
}
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
if status != common.ChannelStatusEnabled {
return false
}
CacheUpdateChannelStatus(channelId, status)
} }
CacheUpdateChannelStatus(id, status)
} }
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
shouldUpdateAbilities := false
defer func() {
if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
}
}
}()
channel, err := GetChannelById(channelId, true)
if err != nil { if err != nil {
common.SysError("failed to update ability status: " + err.Error())
return false return false
}
channel, err := GetChannelById(id, true)
if err != nil {
// find channel by id error, directly update status
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
common.SysError("failed to update channel status: " + result.Error.Error())
return false
}
if result.RowsAffected == 0 {
return false
}
} else { } else {
if channel.Status == status { if channel.Status == status {
return false return false
} }
// find channel by id success, update status and other info
info := channel.GetOtherInfo() if channel.ChannelInfo.IsMultiKey {
info["status_reason"] = reason beforeStatus := channel.Status
info["status_time"] = common.GetTimestamp() handlerMultiKeyUpdate(channel, usingKey, status)
channel.SetOtherInfo(info) if beforeStatus != channel.Status {
channel.Status = status shouldUpdateAbilities = true
}
} else {
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
shouldUpdateAbilities = true
}
err = channel.Save() err = channel.Save()
if err != nil { if err != nil {
common.SysError("failed to update channel status: " + err.Error()) common.SysError("failed to update channel status: " + err.Error())
@@ -628,6 +676,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
err := json.Unmarshal([]byte(*channel.Setting), &setting) err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysError("failed to unmarshal setting: " + err.Error())
channel.Setting = nil // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
} }
} }
return setting return setting

View File

@@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
@@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
apiKey := c.Request.Header.Get("Authorization") apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey) appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId a.AppID = appId

View File

@@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
IsModelMapped: false, IsModelMapped: false,
ApiType: apiType, ApiType: apiType,
ApiVersion: c.GetString("api_version"), ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"), ChannelCreateTime: c.GetInt64("channel_create_time"),

View File

@@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
common.SysError("get_channel_null: " + err.Error()) common.SysError("get_channel_null: " + err.Error())
} }
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
} }
} }
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

View File

@@ -17,17 +17,17 @@ func formatNotifyType(channelId int, status int) string {
} }
// disable & notify // disable & notify
func DisableChannel(channelId int, channelName string, reason string) { func DisableChannel(channelError types.ChannelError, reason string) {
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
if success { if success {
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId) subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelError.ChannelName, channelError.ChannelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason) content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelError.ChannelName, channelError.ChannelId, reason)
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content) NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
} }
} }
func EnableChannel(channelId int, channelName string) { func EnableChannel(channelId int, usingKey string, channelName string) {
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
if success { if success {
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId) subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId) content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
@@ -87,13 +87,10 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
return search return search
} }
func ShouldEnableChannel(err error, newAPIError *types.NewAPIError, status int) bool { func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
if !common.AutomaticEnableChannelEnabled { if !common.AutomaticEnableChannelEnabled {
return false return false
} }
if err != nil {
return false
}
if newAPIError != nil { if newAPIError != nil {
return false return false
} }

View File

@@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
auth := c.Request.Header.Get("Authorization") auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
if auth != "" { if auth != "" {
auth = strings.TrimPrefix(auth, "Bearer ") auth = strings.TrimPrefix(auth, "Bearer ")
req.Header.Set("mj-api-secret", auth) req.Header.Set("mj-api-secret", auth)

21
types/channel_error.go Normal file
View File

@@ -0,0 +1,21 @@
package types
type ChannelError struct {
ChannelId int `json:"channel_id"`
ChannelType int `json:"channel_type"`
ChannelName string `json:"channel_name"`
IsMultiKey bool `json:"is_multi_key"`
AutoBan bool `json:"auto_ban"`
UsingKey string `json:"using_key"`
}
func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError {
return &ChannelError{
ChannelId: channelId,
ChannelType: channelType,
ChannelName: channelName,
IsMultiKey: isMultiKey,
AutoBan: autoBan,
UsingKey: usingKey,
}
}

View File

@@ -50,6 +50,7 @@ const (
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
// client request error // client request error
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"

View File

@@ -42,6 +42,7 @@ import {
IconTreeTriangleDown, IconTreeTriangleDown,
IconSearch, IconSearch,
IconMore, IconMore,
IconList
} from '@douyinfe/semi-icons'; } from '@douyinfe/semi-icons';
import { loadChannelModels, isMobile, copy } from '../../helpers'; import { loadChannelModels, isMobile, copy } from '../../helpers';
import EditTagModal from '../../pages/Channel/EditTagModal.js'; import EditTagModal from '../../pages/Channel/EditTagModal.js';
@@ -53,7 +54,7 @@ const ChannelsTable = () => {
let type2label = undefined; let type2label = undefined;
const renderType = (type) => { const renderType = (type, multiKey = false) => {
if (!type2label) { if (!type2label) {
type2label = new Map(); type2label = new Map();
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
@@ -61,12 +62,24 @@ const ChannelsTable = () => {
} }
type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' }; type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
} }
let icon = getChannelIcon(type);
if (multiKey) {
icon = (
<div className="flex items-center gap-1">
<IconList className="text-blue-500" />
{icon}
</div>
);
}
return ( return (
<Tag <Tag
size='large' size='large'
color={type2label[type]?.color} color={type2label[type]?.color}
shape='circle' shape='circle'
prefixIcon={getChannelIcon(type)} prefixIcon={icon}
> >
{type2label[type]?.label} {type2label[type]?.label}
</Tag> </Tag>
@@ -86,7 +99,19 @@ const ChannelsTable = () => {
); );
}; };
const renderStatus = (status) => { const renderStatus = (status, channelInfo = undefined) => {
if (channelInfo) {
if (channelInfo.is_multi_key) {
let keySize = channelInfo.multi_key_size;
let enabledKeySize = keySize;
if (channelInfo.multi_key_status_list) {
// multi_key_status_list is a map, key is key, value is status
// get multi_key_status_list length
enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length;
}
return renderMultiKeyStatus(status, keySize, enabledKeySize);
}
}
switch (status) { switch (status) {
case 1: case 1:
return ( return (
@@ -115,6 +140,36 @@ const ChannelsTable = () => {
} }
}; };
const renderMultiKeyStatus = (status, keySize, enabledKeySize) => {
switch (status) {
case 1:
return (
<Tag size='large' color='green' shape='circle'>
{t('已启用')} {enabledKeySize}/{keySize}
</Tag>
);
case 2:
return (
<Tag size='large' color='red' shape='circle'>
{t('已禁用')} {enabledKeySize}/{keySize}
</Tag>
);
case 3:
return (
<Tag size='large' color='yellow' shape='circle'>
{t('自动禁用')} {enabledKeySize}/{keySize}
</Tag>
);
default:
return (
<Tag size='large' color='grey' shape='circle'>
{t('未知状态')} {enabledKeySize}/{keySize}
</Tag>
);
}
}
const renderResponseTime = (responseTime) => { const renderResponseTime = (responseTime) => {
let time = responseTime / 1000; let time = responseTime / 1000;
time = time.toFixed(2) + t(' 秒'); time = time.toFixed(2) + t(' 秒');
@@ -281,6 +336,11 @@ const ChannelsTable = () => {
dataIndex: 'type', dataIndex: 'type',
render: (text, record, index) => { render: (text, record, index) => {
if (record.children === undefined) { if (record.children === undefined) {
if (record.channel_info) {
if (record.channel_info.is_multi_key) {
return <>{renderType(text, record.channel_info)}</>;
}
}
return <>{renderType(text)}</>; return <>{renderType(text)}</>;
} else { } else {
return <>{renderTagType()}</>; return <>{renderTagType()}</>;
@@ -304,12 +364,12 @@ const ChannelsTable = () => {
<Tooltip <Tooltip
content={t('原因:') + reason + t(',时间:') + timestamp2string(time)} content={t('原因:') + reason + t(',时间:') + timestamp2string(time)}
> >
{renderStatus(text)} {renderStatus(text, record.channel_info)}
</Tooltip> </Tooltip>
</div> </div>
); );
} else { } else {
return renderStatus(text); return renderStatus(text, record.channel_info);
} }
}, },
}, },