@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
func Relay ( c * gin . Context ) {
relayMode := constant . Path2RelayMode ( c . Request . URL . Path )
retryTimes := common . RetryTimes
requestId := c . GetString ( common . RequestIdKey )
channelId := c . GetInt ( "channel_id" )
channelType := c . GetInt ( "channel_type" )
channelName := c . GetString ( "channel_name" )
group := c . GetString ( "group" )
originalModel := c . GetString ( "original_model" )
openaiErr := relayHandler ( c , relayM ode)
c . Set ( "use_channel" , [ ] string { fmt . Sprintf ( "%d" , channelId ) } )
if openaiErr ! = nil {
go processChannelError ( c , channelId , channelType , channelName , openaiErr )
} else {
retryTimes = 0
}
for i := 0 ; shouldRetry ( c , channelId , openaiErr , retryTimes ) && i < retryTimes ; i ++ {
channel , err := model . CacheGetRandomSatisfiedChannel ( group , originalModel , i )
var openaiErr * dto . OpenAIErrorWithStatusC ode
for i := 0 ; i < = common . RetryTimes ; i ++ {
channel , err := getChannel ( c , group , originalModel , i )
if err != nil {
common . LogError ( c . Request . Context ( ) , fmt . Sprintf ( "CacheGetRandomSatisfiedChannel failed: %s" , err . Error ( ) ) )
common . LogError ( c , err . Error ( ) )
openaiErr = service . OpenAIErrorWrapperLocal ( err , "get_channel_failed" , http . StatusInternalServerError )
break
}
channelId = channel . Id
useChannel := c . GetStringSlice ( "use_channel" )
useChannel = append ( useChannel , fmt . Sprintf ( "%d" , channel . Id ) )
c . Set ( "use_channel" , useChannel )
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "using channel #%d to retry (remain times %d)" , channel . Id , i ) )
middleware . SetupContextForSelectedChannel ( c , channel , originalModel )
requestBody , e rr : = common . GetRequestB ody ( c )
c . Request . Body = io . NopCloser ( bytes . NewBuffer ( requestBody ) )
openaiErr = relayHandler ( c , relayMode )
if openaiErr != nil {
go processChannelError ( c , channel . Id , channel . Type , channel . Name , openaiErr )
openaiE rr = relayRequest ( c , relayM ode , channel )
if openaiErr = = nil {
return // 成功处理请求,直接返回
}
go processChannelError ( c , channel . Id , channel . Type , channel . Name , channel . GetAutoBan ( ) , openaiErr )
if ! shouldRetry ( c , openaiErr , common . RetryTimes - i ) {
break
}
}
useChannel := c . GetStringSlice ( "use_channel" )
if len ( useChannel ) > 1 {
retryLogStr := fmt . Sprintf ( "重试:%s" , strings . Trim ( strings . Join ( strings . Fields ( fmt . Sprint ( useChannel ) ) , "->" ) , "[]" ) )
common . LogInfo ( c . Request . Context ( ) , retryLogStr )
common . LogInfo ( c , retryLogStr )
}
if openaiErr != nil {
@@ -90,7 +82,42 @@ func Relay(c *gin.Context) {
}
}
func shouldRetry ( c * gin . Context , channelId int , openaiErr * dto . OpenAIErrorWithStatusCode , retryTimes int ) bool {
func relayRequest ( c * gin . Context , relayMode int , channel * model . Channel ) * dto . OpenAIErrorWithStatusCode {
addUsedChannel ( c , channel . Id )
requestBody , _ := common . GetRequestBody ( c )
c . Request . Body = io . NopCloser ( bytes . NewBuffer ( requestBody ) )
return relayHandler ( c , relayMode )
}
func addUsedChannel ( c * gin . Context , channelId int ) {
useChannel := c . GetStringSlice ( "use_channel" )
useChannel = append ( useChannel , fmt . Sprintf ( "%d" , channelId ) )
c . Set ( "use_channel" , useChannel )
}
func getChannel ( c * gin . Context , group , originalModel string , retryCount int ) ( * model . Channel , error ) {
if retryCount == 0 {
autoBan := c . GetBool ( "auto_ban" )
autoBanInt := 1
if ! autoBan {
autoBanInt = 0
}
return & model . Channel {
Id : c . GetInt ( "channel_id" ) ,
Type : c . GetInt ( "channel_type" ) ,
Name : c . GetString ( "channel_name" ) ,
AutoBan : & autoBanInt ,
} , nil
}
channel , err := model . CacheGetRandomSatisfiedChannel ( group , originalModel , retryCount )
if err != nil {
return nil , errors . New ( fmt . Sprintf ( "获取重试渠道失败: %s" , err . Error ( ) ) )
}
middleware . SetupContextForSelectedChannel ( c , channel , originalModel )
return channel , nil
}
func shouldRetry ( c * gin . Context , openaiErr * dto . OpenAIErrorWithStatusCode , retryTimes int ) bool {
if openaiErr == nil {
return false
}
@@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr . StatusCode == http . StatusBadRequest {
channelType := c . GetInt ( "channel_type" )
if channelType == common . ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr . StatusCode == 408 {
@@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
func processChannelError ( c * gin . Context , channelId int , channelType int , channelName string , err * dto . OpenAIErrorWithStatusCode ) {
autoBan := c . GetBool ( "auto_ban" )
common . LogError ( c . Request . Context ( ) , fmt . Sprintf ( "relay error (channel #%d, status code: %d): %s" , channelId , err . StatusCode , err . Error . Message ) )
func processChannelError ( c * gin . Context , channelId int , channelType int , channelName string , autoBan bool , err * dto . OpenAIErrorWithStatusCode ) {
// 不要使用context获取渠道信息, 异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common . LogError ( c , fmt . Sprintf ( "relay error (channel #%d, status code: %d): %s" , channelId , err . StatusCode , err . Error . Message ) )
if service . ShouldDisableChannel ( channelType , err ) && autoBan {
service . DisableChannel ( channelId , channelName , err . Error . Message )
}
@@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0 ; shouldRetryTaskRelay ( c , channelId , taskErr , retryTimes ) && i < retryTimes ; i ++ {
channel , err := model . CacheGetRandomSatisfiedChannel ( group , originalModel , i )
if err != nil {
common . LogError ( c . Request . Context ( ) , fmt . Sprintf ( "CacheGetRandomSatisfiedChannel failed: %s" , err . Error ( ) ) )
common . LogError ( c , fmt . Sprintf ( "CacheGetRandomSatisfiedChannel failed: %s" , err . Error ( ) ) )
break
}
channelId = channel . Id
useChannel := c . GetStringSlice ( "use_channel" )
useChannel = append ( useChannel , fmt . Sprintf ( "%d" , channelId ) )
c . Set ( "use_channel" , useChannel )
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "using channel #%d to retry (remain times %d)" , channel . Id , i ) )
common . LogInfo ( c , fmt . Sprintf ( "using channel #%d to retry (remain times %d)" , channel . Id , i ) )
middleware . SetupContextForSelectedChannel ( c , channel , originalModel )
requestBody , err := common . GetRequestBody ( c )
@@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c . GetStringSlice ( "use_channel" )
if len ( useChannel ) > 1 {
retryLogStr := fmt . Sprintf ( "重试:%s" , strings . Trim ( strings . Join ( strings . Fields ( fmt . Sprint ( useChannel ) ) , "->" ) , "[]" ) )
common . LogInfo ( c . Request . Context ( ) , retryLogStr )
common . LogInfo ( c , retryLogStr )
}
if taskErr != nil {
if taskErr . StatusCode == http . StatusTooManyRequests {