✨ feat(channel): enhance channel status management
This commit is contained in:
@@ -29,6 +29,7 @@ const (
|
|||||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||||
|
ContextKeyChannelKey ContextKey = "channel_key"
|
||||||
|
|
||||||
/* user related keys */
|
/* user related keys */
|
||||||
ContextKeyUserId ContextKey = "id"
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
|||||||
@@ -452,14 +452,14 @@ func updateAllChannelsBalance() error {
|
|||||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||||
// continue
|
// continue
|
||||||
//}
|
//}
|
||||||
balance, err := updateChannelBalance(channel)
|
_, err := updateChannelBalance(channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
// err is nil & balance <= 0 means quota is used up
|
// err is nil & balance <= 0 means quota is used up
|
||||||
if balance <= 0 {
|
//if balance <= 0 {
|
||||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
// service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,22 +30,43 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
|
type testResult struct {
|
||||||
|
context *gin.Context
|
||||||
|
localErr error
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == constant.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||||
return errors.New("midjourney plus channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("suno channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeKling {
|
if channel.Type == constant.ChannelTypeKling {
|
||||||
return errors.New("kling channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("kling channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeJimeng {
|
if channel.Type == constant.ChannelTypeJimeng {
|
||||||
return errors.New("jimeng channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("jimeng channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -82,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cache.WriteContext(c)
|
cache.WriteContext(c)
|
||||||
|
|
||||||
@@ -93,20 +117,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
group, _ := model.GetUserGroup(1, false)
|
group, _ := model.GetUserGroup(1, false)
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: newAPIError,
|
||||||
|
newAPIError: newAPIError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info := relaycommon.GenRelayInfo(c)
|
info := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info, nil)
|
err = helper.ModelMappedHelper(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||||
|
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
request := buildTestRequest(testModel)
|
||||||
@@ -117,45 +156,77 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeModelPriceError)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
c.Request.Body = io.NopCloser(requestBody)
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
err := service.RelayErrorHandler(httpResp, true)
|
err := service.RelayErrorHandler(httpResp, true)
|
||||||
return err, types.NewError(err, types.ErrorCodeBadResponse)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return respErr, respErr
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: respErr,
|
||||||
|
newAPIError: respErr,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if usageA == nil {
|
if usageA == nil {
|
||||||
return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("usage is nil"),
|
||||||
|
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
usage := usageA.(*dto.Usage)
|
usage := usageA.(*dto.Usage)
|
||||||
result := w.Result()
|
result := w.Result()
|
||||||
respBody, err := io.ReadAll(result.Body)
|
respBody, err := io.ReadAll(result.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
info.PromptTokens = usage.PromptTokens
|
info.PromptTokens = usage.PromptTokens
|
||||||
|
|
||||||
@@ -188,7 +259,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return nil, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: nil,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||||
@@ -247,15 +322,23 @@ func TestChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
_, newAPIError := testChannel(channel, testModel)
|
result := testChannel(channel, testModel)
|
||||||
|
if result.localErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": result.localErr.Error(),
|
||||||
|
"time": 0.0,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
if newAPIError != nil {
|
if result.newAPIError != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": newAPIError.Error(),
|
"message": result.newAPIError.Error(),
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -280,9 +363,9 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
testAllChannelsRunning = true
|
testAllChannelsRunning = true
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||||
if err != nil {
|
if getChannelErr != nil {
|
||||||
return err
|
return getChannelErr
|
||||||
}
|
}
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||||
if disableThreshold == 0 {
|
if disableThreshold == 0 {
|
||||||
@@ -299,30 +382,34 @@ func testAllChannels(notify bool) error {
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, newAPIError := testChannel(channel, "")
|
result := testChannel(channel, "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
shouldBanChannel := false
|
shouldBanChannel := false
|
||||||
|
newAPIError := result.newAPIError
|
||||||
// request error disables the channel
|
// request error disables the channel
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
|
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 当错误检查通过,才检查响应时间
|
||||||
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||||
shouldBanChannel = true
|
shouldBanChannel = true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// disable channel
|
// disable channel
|
||||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
|
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
|
|||||||
@@ -497,6 +497,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
|
||||||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||||
} else {
|
} else {
|
||||||
cleanKeys := make([]string, 0)
|
cleanKeys := make([]string, 0)
|
||||||
@@ -507,6 +508,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
key = strings.TrimSpace(key)
|
key = strings.TrimSpace(key)
|
||||||
cleanKeys = append(cleanKeys, key)
|
cleanKeys = append(cleanKeys, key)
|
||||||
}
|
}
|
||||||
|
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
||||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||||
}
|
}
|
||||||
keys = []string{addChannelRequest.Channel.Key}
|
keys = []string{addChannelRequest.Channel.Key}
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func Relay(c *gin.Context) {
|
|||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ func Relay(c *gin.Context) {
|
|||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
@@ -103,10 +103,10 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||||
}
|
//}
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
c.JSON(newAPIError.StatusCode, gin.H{
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
"error": newAPIError.ToOpenAIError(),
|
"error": newAPIError.ToOpenAIError(),
|
||||||
@@ -143,7 +143,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
@@ -166,9 +166,9 @@ func WssRelay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||||
}
|
//}
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||||
}
|
}
|
||||||
@@ -185,7 +185,7 @@ func RelayClaude(c *gin.Context) {
|
|||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,7 +195,7 @@ func RelayClaude(c *gin.Context) {
|
|||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
@@ -243,7 +243,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
|||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||||
if retryCount == 0 {
|
if retryCount == 0 {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
autoBanInt := 1
|
autoBanInt := 1
|
||||||
@@ -260,11 +260,14 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if group == "auto" {
|
if group == "auto" {
|
||||||
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
|
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||||
}
|
}
|
||||||
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
|
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||||
|
}
|
||||||
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return nil, newAPIError
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
return channel, nil
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,12 +317,12 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error()))
|
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
service.DisableChannel(channelId, channelName, err.Error())
|
service.DisableChannel(channelError, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,10 +395,10 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
@@ -405,7 +408,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
taskErr = taskRelayHandler(c, relayMode)
|
taskErr = taskRelayHandler(c, relayMode)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -249,10 +250,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
return &modelRequest, shouldSelectChannel, nil
|
return &modelRequest, shouldSelectChannel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||||
c.Set("original_model", modelName) // for retry
|
c.Set("original_model", modelName) // for retry
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
||||||
}
|
}
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||||
@@ -270,7 +271,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
|
|
||||||
}
|
}
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
||||||
|
key, newAPIError := channel.GetNextEnabledKey()
|
||||||
|
if newAPIError != nil {
|
||||||
|
return newAPIError
|
||||||
|
}
|
||||||
|
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||||
|
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
@@ -292,6 +299,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
case constant.ChannelTypeCoze:
|
case constant.ChannelTypeCoze:
|
||||||
c.Set("bot_id", channel.Other)
|
c.Set("bot_id", channel.Other)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||||
|
|||||||
@@ -203,3 +203,16 @@ func CacheUpdateChannelStatus(id int, status int) {
|
|||||||
channel.Status = status
|
channel.Status = status
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheUpdateChannel(channel *Channel) {
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channelSyncLock.Lock()
|
||||||
|
defer channelSyncLock.Unlock()
|
||||||
|
|
||||||
|
if channel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channelsIDM[channel.Id] = channel
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ package model
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ type Channel struct {
|
|||||||
|
|
||||||
type ChannelInfo struct {
|
type ChannelInfo struct {
|
||||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||||
|
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
@@ -73,7 +75,7 @@ func (channel *Channel) getKeys() []string {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetNextEnabledKey() (string, error) {
|
func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
|
||||||
// If not in multi-key mode, return the original key string directly.
|
// If not in multi-key mode, return the original key string directly.
|
||||||
if !channel.ChannelInfo.IsMultiKey {
|
if !channel.ChannelInfo.IsMultiKey {
|
||||||
return channel.Key, nil
|
return channel.Key, nil
|
||||||
@@ -83,7 +85,7 @@ func (channel *Channel) GetNextEnabledKey() (string, error) {
|
|||||||
keys := channel.getKeys()
|
keys := channel.getKeys()
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
// No keys available, return error, should disable the channel
|
// No keys available, return error, should disable the channel
|
||||||
return "", fmt.Errorf("no valid keys in channel")
|
return "", types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||||
@@ -404,48 +406,94 @@ func (channel *Channel) Delete() error {
|
|||||||
|
|
||||||
var channelStatusLock sync.Mutex
|
var channelStatusLock sync.Mutex
|
||||||
|
|
||||||
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||||
|
keys := channel.getKeys()
|
||||||
|
if len(keys) == 0 {
|
||||||
|
channel.Status = status
|
||||||
|
} else {
|
||||||
|
var keyIndex int
|
||||||
|
for i, key := range keys {
|
||||||
|
if key == usingKey {
|
||||||
|
keyIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||||
|
}
|
||||||
|
if status == common.ChannelStatusEnabled {
|
||||||
|
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||||
|
} else {
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||||
|
}
|
||||||
|
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||||
|
channel.Status = common.ChannelStatusAutoDisabled
|
||||||
|
info := channel.GetOtherInfo()
|
||||||
|
info["status_reason"] = "All keys are disabled"
|
||||||
|
info["status_time"] = common.GetTimestamp()
|
||||||
|
channel.SetOtherInfo(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
channelStatusLock.Lock()
|
channelStatusLock.Lock()
|
||||||
defer channelStatusLock.Unlock()
|
defer channelStatusLock.Unlock()
|
||||||
|
|
||||||
channelCache, _ := CacheGetChannel(id)
|
channelCache, _ := CacheGetChannel(channelId)
|
||||||
|
if channelCache == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if channelCache.ChannelInfo.IsMultiKey {
|
||||||
|
// 如果是多Key模式,更新缓存中的状态
|
||||||
|
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||||
|
CacheUpdateChannel(channelCache)
|
||||||
|
//return true
|
||||||
|
} else {
|
||||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||||
if channelCache != nil && channelCache.Status == status {
|
if channelCache.Status == status {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||||
if channelCache == nil && status != common.ChannelStatusEnabled {
|
if status != common.ChannelStatusEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
CacheUpdateChannelStatus(id, status)
|
CacheUpdateChannelStatus(channelId, status)
|
||||||
}
|
}
|
||||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
}
|
||||||
|
|
||||||
|
shouldUpdateAbilities := false
|
||||||
|
defer func() {
|
||||||
|
if shouldUpdateAbilities {
|
||||||
|
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update ability status: " + err.Error())
|
common.SysError("failed to update ability status: " + err.Error())
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
channel, err := GetChannelById(id, true)
|
}
|
||||||
|
}()
|
||||||
|
channel, err := GetChannelById(channelId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// find channel by id error, directly update status
|
|
||||||
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
|
||||||
if result.Error != nil {
|
|
||||||
common.SysError("failed to update channel status: " + result.Error.Error())
|
|
||||||
return false
|
return false
|
||||||
}
|
|
||||||
if result.RowsAffected == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if channel.Status == status {
|
if channel.Status == status {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// find channel by id success, update status and other info
|
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
beforeStatus := channel.Status
|
||||||
|
handlerMultiKeyUpdate(channel, usingKey, status)
|
||||||
|
if beforeStatus != channel.Status {
|
||||||
|
shouldUpdateAbilities = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
info := channel.GetOtherInfo()
|
info := channel.GetOtherInfo()
|
||||||
info["status_reason"] = reason
|
info["status_reason"] = reason
|
||||||
info["status_time"] = common.GetTimestamp()
|
info["status_time"] = common.GetTimestamp()
|
||||||
channel.SetOtherInfo(info)
|
channel.SetOtherInfo(info)
|
||||||
channel.Status = status
|
channel.Status = status
|
||||||
|
shouldUpdateAbilities = true
|
||||||
|
}
|
||||||
err = channel.Save()
|
err = channel.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
common.SysError("failed to update channel status: " + err.Error())
|
||||||
@@ -628,6 +676,8 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
|||||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||||
|
channel.Setting = nil // 清空设置以避免后续错误
|
||||||
|
_ = channel.Save() // 保存修改
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return setting
|
return setting
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -63,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||||
a.AppID = appId
|
a.AppID = appId
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
IsModelMapped: false,
|
IsModelMapped: false,
|
||||||
ApiType: apiType,
|
ApiType: apiType,
|
||||||
ApiVersion: c.GetString("api_version"),
|
ApiVersion: c.GetString("api_version"),
|
||||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||||
Organization: c.GetString("channel_organization"),
|
Organization: c.GetString("channel_organization"),
|
||||||
|
|
||||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||||
|
|||||||
@@ -575,7 +575,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
common.SysError("get_channel_null: " + err.Error())
|
common.SysError("get_channel_null: " + err.Error())
|
||||||
}
|
}
|
||||||
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||||
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||||
|
|||||||
@@ -17,17 +17,17 @@ func formatNotifyType(channelId int, status int) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// disable & notify
|
// disable & notify
|
||||||
func DisableChannel(channelId int, channelName string, reason string) {
|
func DisableChannel(channelError types.ChannelError, reason string) {
|
||||||
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
|
||||||
if success {
|
if success {
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)
|
||||||
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
|
NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableChannel(channelId int, channelName string) {
|
func EnableChannel(channelId int, usingKey string, channelName string) {
|
||||||
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
|
||||||
if success {
|
if success {
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
@@ -87,13 +87,10 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
|
|||||||
return search
|
return search
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShouldEnableChannel(err error, newAPIError *types.NewAPIError, status int) bool {
|
func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
|
||||||
if !common.AutomaticEnableChannelEnabled {
|
if !common.AutomaticEnableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
auth := c.Request.Header.Get("Authorization")
|
auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
|
||||||
if auth != "" {
|
if auth != "" {
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||||
req.Header.Set("mj-api-secret", auth)
|
req.Header.Set("mj-api-secret", auth)
|
||||||
|
|||||||
21
types/channel_error.go
Normal file
21
types/channel_error.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type ChannelError struct {
|
||||||
|
ChannelId int `json:"channel_id"`
|
||||||
|
ChannelType int `json:"channel_type"`
|
||||||
|
ChannelName string `json:"channel_name"`
|
||||||
|
IsMultiKey bool `json:"is_multi_key"`
|
||||||
|
AutoBan bool `json:"auto_ban"`
|
||||||
|
UsingKey string `json:"using_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChannelError(channelId int, channelType int, channelName string, isMultiKey bool, usingKey string, autoBan bool) *ChannelError {
|
||||||
|
return &ChannelError{
|
||||||
|
ChannelId: channelId,
|
||||||
|
ChannelType: channelType,
|
||||||
|
ChannelName: channelName,
|
||||||
|
IsMultiKey: isMultiKey,
|
||||||
|
AutoBan: autoBan,
|
||||||
|
UsingKey: usingKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -50,6 +50,7 @@ const (
|
|||||||
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
|
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
|
||||||
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
|
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
|
||||||
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
|
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
|
||||||
|
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
|
||||||
|
|
||||||
// client request error
|
// client request error
|
||||||
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
|
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ import {
|
|||||||
IconTreeTriangleDown,
|
IconTreeTriangleDown,
|
||||||
IconSearch,
|
IconSearch,
|
||||||
IconMore,
|
IconMore,
|
||||||
|
IconList
|
||||||
} from '@douyinfe/semi-icons';
|
} from '@douyinfe/semi-icons';
|
||||||
import { loadChannelModels, isMobile, copy } from '../../helpers';
|
import { loadChannelModels, isMobile, copy } from '../../helpers';
|
||||||
import EditTagModal from '../../pages/Channel/EditTagModal.js';
|
import EditTagModal from '../../pages/Channel/EditTagModal.js';
|
||||||
@@ -53,7 +54,7 @@ const ChannelsTable = () => {
|
|||||||
|
|
||||||
let type2label = undefined;
|
let type2label = undefined;
|
||||||
|
|
||||||
const renderType = (type) => {
|
const renderType = (type, multiKey = false) => {
|
||||||
if (!type2label) {
|
if (!type2label) {
|
||||||
type2label = new Map();
|
type2label = new Map();
|
||||||
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
|
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
|
||||||
@@ -61,12 +62,24 @@ const ChannelsTable = () => {
|
|||||||
}
|
}
|
||||||
type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
|
type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let icon = getChannelIcon(type);
|
||||||
|
|
||||||
|
if (multiKey) {
|
||||||
|
icon = (
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<IconList className="text-blue-500" />
|
||||||
|
{icon}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tag
|
<Tag
|
||||||
size='large'
|
size='large'
|
||||||
color={type2label[type]?.color}
|
color={type2label[type]?.color}
|
||||||
shape='circle'
|
shape='circle'
|
||||||
prefixIcon={getChannelIcon(type)}
|
prefixIcon={icon}
|
||||||
>
|
>
|
||||||
{type2label[type]?.label}
|
{type2label[type]?.label}
|
||||||
</Tag>
|
</Tag>
|
||||||
@@ -86,7 +99,19 @@ const ChannelsTable = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderStatus = (status) => {
|
const renderStatus = (status, channelInfo = undefined) => {
|
||||||
|
if (channelInfo) {
|
||||||
|
if (channelInfo.is_multi_key) {
|
||||||
|
let keySize = channelInfo.multi_key_size;
|
||||||
|
let enabledKeySize = keySize;
|
||||||
|
if (channelInfo.multi_key_status_list) {
|
||||||
|
// multi_key_status_list is a map, key is key, value is status
|
||||||
|
// get multi_key_status_list length
|
||||||
|
enabledKeySize = keySize - Object.keys(channelInfo.multi_key_status_list).length;
|
||||||
|
}
|
||||||
|
return renderMultiKeyStatus(status, keySize, enabledKeySize);
|
||||||
|
}
|
||||||
|
}
|
||||||
switch (status) {
|
switch (status) {
|
||||||
case 1:
|
case 1:
|
||||||
return (
|
return (
|
||||||
@@ -115,6 +140,36 @@ const ChannelsTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const renderMultiKeyStatus = (status, keySize, enabledKeySize) => {
|
||||||
|
switch (status) {
|
||||||
|
case 1:
|
||||||
|
return (
|
||||||
|
<Tag size='large' color='green' shape='circle'>
|
||||||
|
{t('已启用')} {enabledKeySize}/{keySize}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
case 2:
|
||||||
|
return (
|
||||||
|
<Tag size='large' color='red' shape='circle'>
|
||||||
|
{t('已禁用')} {enabledKeySize}/{keySize}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
case 3:
|
||||||
|
return (
|
||||||
|
<Tag size='large' color='yellow' shape='circle'>
|
||||||
|
{t('自动禁用')} {enabledKeySize}/{keySize}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
default:
|
||||||
|
return (
|
||||||
|
<Tag size='large' color='grey' shape='circle'>
|
||||||
|
{t('未知状态')} {enabledKeySize}/{keySize}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
const renderResponseTime = (responseTime) => {
|
const renderResponseTime = (responseTime) => {
|
||||||
let time = responseTime / 1000;
|
let time = responseTime / 1000;
|
||||||
time = time.toFixed(2) + t(' 秒');
|
time = time.toFixed(2) + t(' 秒');
|
||||||
@@ -281,6 +336,11 @@ const ChannelsTable = () => {
|
|||||||
dataIndex: 'type',
|
dataIndex: 'type',
|
||||||
render: (text, record, index) => {
|
render: (text, record, index) => {
|
||||||
if (record.children === undefined) {
|
if (record.children === undefined) {
|
||||||
|
if (record.channel_info) {
|
||||||
|
if (record.channel_info.is_multi_key) {
|
||||||
|
return <>{renderType(text, record.channel_info)}</>;
|
||||||
|
}
|
||||||
|
}
|
||||||
return <>{renderType(text)}</>;
|
return <>{renderType(text)}</>;
|
||||||
} else {
|
} else {
|
||||||
return <>{renderTagType()}</>;
|
return <>{renderTagType()}</>;
|
||||||
@@ -304,12 +364,12 @@ const ChannelsTable = () => {
|
|||||||
<Tooltip
|
<Tooltip
|
||||||
content={t('原因:') + reason + t(',时间:') + timestamp2string(time)}
|
content={t('原因:') + reason + t(',时间:') + timestamp2string(time)}
|
||||||
>
|
>
|
||||||
{renderStatus(text)}
|
{renderStatus(text, record.channel_info)}
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
return renderStatus(text);
|
return renderStatus(text, record.channel_info);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user