@@ -38,7 +38,7 @@ type testResult struct {
newAPIError * types . NewAPIError
}
func testChannel ( channel * model . Channel , testModel string ) testResult {
func testChannel ( channel * model . Channel , testModel string , endpointType string ) testResult {
tik := time . Now ( )
if channel . Type == constant . ChannelTypeMidjourney {
return testResult {
@@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings . Contains ( strings . ToLower ( testModel ) , "embedding" ) ||
strings . HasPrefix ( testModel , "m3e" ) || // m3e 系列模型
strings . Contains ( testModel , "bge-" ) || // bge 系列模型
strings . Contains ( testModel , "embed" ) ||
channel . Type == constant . ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
// 如果指定了端点类型,使用指定的端点类型
if endpointType != "" {
if endpointInfo , ok := common . GetDefaultEndpointInfo ( constant . EndpointType ( endpointType ) ) ; ok {
requestPath = endpointInfo . Path
}
} else {
// 如果没有指定端点类型,使用原有的自动检测逻辑
// 先判断是否为 Embedding 模型
if strings . Contains ( strings . ToLower ( testModel ) , "embedding" ) ||
strings . HasPrefix ( testModel , "m3e" ) || // m3e 系列模型
strings . Contains ( testModel , "bge-" ) || // bge 系列模型
strings . Contains ( testModel , "embed" ) ||
channel . Type == constant . ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
// VolcEngine 图像生成模型
if channel . Type == constant . ChannelTypeVolcEngine && strings . Contains ( testModel , "seedream" ) {
requestPath = "/v1/images/generations"
// VolcEngine 图像生成模型
if channel . Type == constant . ChannelTypeVolcEngine && strings . Contains ( testModel , "seedream" ) {
requestPath = "/v1/images/generations"
}
}
c . Request = & http . Request {
@@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
}
// 重新检查模型类型并更新请求路径
if strings . Contains ( strings . ToLower ( testModel ) , "embedding" ) ||
strings . HasPrefix ( testModel , "m3e" ) ||
strings . Contains ( testModel , "bge-" ) ||
strings . Contains ( testModel , "embed" ) ||
channel . Type == constant . ChannelTypeMokaAI {
requestPath = "/v1/embeddings"
c . Request . URL . Path = requestPath
}
if channel . Type == constant . ChannelTypeVolcEngine && strings . Contains ( testModel , "seedream" ) {
requestPath = "/v1/images/generations"
c . Request . URL . Path = requestPath
}
cache , err := model . GetUserCache ( 1 )
if err != nil {
return testResult {
@@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
newAPIError : newAPIError ,
}
}
request := buildTestRequest ( testModel )
// Determine relay format based on request path
relayFormat := types . RelayFormatOpenAI
if c . Request . URL . Path == "/v1/embeddings" {
relayFormat = types . RelayFormatEmbedding
}
if c . Request . URL . Path == "/v1/images/generations" {
relayFormat = types . RelayFormatOpenAIImage
// Determine relay format based on endpoint type or request path
var relayFormat types . RelayFormat
if endpointType != "" {
// 根据指定的端点类型设置 relayFormat
switch constant . EndpointType ( endpointType ) {
case constant . EndpointTypeOpenAI :
relayFormat = types . RelayFormatOpenAI
case constant . EndpointTypeOpenAIResponse :
relayFormat = types . RelayFormatOpenAIResponses
case constant . EndpointTypeAnthropic :
relayFormat = types . RelayFormatClaude
case constant . EndpointTypeGemini :
relayFormat = types . RelayFormatGemini
case constant . EndpointTypeJinaRerank :
relayFormat = types . RelayFormatRerank
case constant . EndpointTypeImageGeneration :
relayFormat = types . RelayFormatOpenAIImage
case constant . EndpointTypeEmbeddings :
relayFormat = types . RelayFormatEmbedding
default :
relayFormat = types . RelayFormatOpenAI
}
} else {
// 根据请求路径自动检测
relayFormat = types . RelayFormatOpenAI
if c . Request . URL . Path == "/v1/embeddings" {
relayFormat = types . RelayFormatEmbedding
}
if c . Request . URL . Path == "/v1/images/generations" {
relayFormat = types . RelayFormatOpenAIImage
}
if c . Request . URL . Path == "/v1/messages" {
relayFormat = types . RelayFormatClaude
}
if strings . Contains ( c . Request . URL . Path , "/v1beta/models" ) {
relayFormat = types . RelayFormatGemini
}
if c . Request . URL . Path == "/v1/rerank" || c . Request . URL . Path == "/rerank" {
relayFormat = types . RelayFormatRerank
}
if c . Request . URL . Path == "/v1/responses" {
relayFormat = types . RelayFormatOpenAIResponses
}
}
request := buildTestRequest ( testModel , endpointType )
info , err := relaycommon . GenRelayInfo ( c , relayFormat , request , nil )
if err != nil {
@@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
testModel = info . UpstreamModelName
request . Model = testModel
// 更新请求中的模型名称
request . SetModelName ( testModel )
apiType , _ := common . ChannelType2APIType ( channel . Type )
adaptor := relay . GetAdaptor ( apiType )
@@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
var convertedRequest any
// 根据 RelayMode 选择正确的转换函数
if info . RelayMode == relayconstant . RelayModeEmbeddings {
// 创建一个 EmbeddingRequest
embeddingRequest := dto . EmbeddingRequest {
Input : request . Input ,
Model : request . Model ,
}
// 调用专门用于 Embedding 的转换函数
convertedRequest , err = adaptor . ConvertEmbeddingRequest ( c , info , embeddingRequest )
} else if info . RelayMode == relayconstant . RelayModeImagesGenerations {
// 创建一个 ImageRequest
prompt := "cat"
if request . Prompt != nil {
if promptStr , ok := request . Prompt . ( string ) ; ok && promptStr != "" {
prompt = promptStr
switch info . RelayMode {
case relayconstant . RelayModeEmbeddings :
// Embedding 请求 - request 已经是正确的类型
if embeddingReq , ok := request . ( * dto . EmbeddingRequest ) ; ok {
convertedRequest , err = adaptor . ConvertEmbeddingRequest ( c , info , * embeddingReq )
} else {
return testResult {
context : c ,
localErr : errors . New ( "invalid embedding request type" ) ,
newAPIError : types . NewError ( errors . New ( "invalid embedding request type" ) , types . ErrorCodeConvertRequestFailed ) ,
}
}
imageRequest := dto . ImageRequest {
Prompt : prompt ,
Model : request . Model ,
N : uint ( request . N ) ,
Size : request . Size ,
case relayconstant . RelayModeImagesGenerations :
// 图像生成请求 - request 已经是正确的类型
if imageReq , ok := request . ( * dto . ImageRequest ) ; ok {
convertedRequest , err = adaptor . ConvertImageRequest ( c , info , * imageReq )
} else {
return testResult {
context : c ,
localErr : errors . New ( "invalid image request type" ) ,
newAPIError : types . NewError ( errors . New ( "invalid image request type" ) , types . ErrorCodeConvertRequestFailed ) ,
}
}
case relayconstant . RelayModeRerank :
// Rerank 请求 - request 已经是正确的类型
if rerankReq , ok := request . ( * dto . RerankRequest ) ; ok {
convertedRequest , err = adaptor . ConvertRerankRequest ( c , info . RelayMode , * rerankReq )
} else {
return testResult {
context : c ,
localErr : errors . New ( "invalid rerank request type" ) ,
newAPIError : types . NewError ( errors . New ( "invalid rerank request type" ) , types . ErrorCodeConvertRequestFailed ) ,
}
}
case relayconstant . RelayModeResponses :
// Response 请求 - request 已经是正确的类型
if responseReq , ok := request . ( * dto . OpenAIResponsesRequest ) ; ok {
convertedRequest , err = adaptor . ConvertOpenAIResponsesRequest ( c , info , * responseReq )
} else {
return testResult {
context : c ,
localErr : errors . New ( "invalid response request type" ) ,
newAPIError : types . NewError ( errors . New ( "invalid response request type" ) , types . ErrorCodeConvertRequestFailed ) ,
}
}
default :
// Chat/Completion 等其他请求类型
if generalReq , ok := request . ( * dto . GeneralOpenAIRequest ) ; ok {
convertedRequest , err = adaptor . ConvertOpenAIRequest ( c , info , generalReq )
} else {
return testResult {
context : c ,
localErr : errors . New ( "invalid general request type" ) ,
newAPIError : types . NewError ( errors . New ( "invalid general request type" ) , types . ErrorCodeConvertRequestFailed ) ,
}
}
// 调用专门用于图像生成的转换函数
convertedRequest , err = adaptor . ConvertImageRequest ( c , info , imageRequest )
} else {
// 对其他所有请求类型(如 Chat) , 保持原有逻辑
convertedRequest , err = adaptor . ConvertOpenAIRequest ( c , info , request )
}
if err != nil {
@@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
}
func buildTestRequest ( model string ) * dto . GeneralOpenAI Request {
testRequest := & dto . GeneralOpenAIRequest {
Model : "" , // this will be set later
Stream : false ,
func buildTestRequest ( model string , endpointType string ) dto . Request {
// 根据端点类型构建不同的测试请求
if endpointType != "" {
switch constant . EndpointType ( endpointType ) {
case constant . EndpointTypeEmbeddings :
// 返回 EmbeddingRequest
return & dto . EmbeddingRequest {
Model : model ,
Input : [ ] any { "hello world" } ,
}
case constant . EndpointTypeImageGeneration :
// 返回 ImageRequest
return & dto . ImageRequest {
Model : model ,
Prompt : "a cute cat" ,
N : 1 ,
Size : "1024x1024" ,
}
case constant . EndpointTypeJinaRerank :
// 返回 RerankRequest
return & dto . RerankRequest {
Model : model ,
Query : "What is Deep Learning?" ,
Documents : [ ] any { "Deep Learning is a subset of machine learning." , "Machine learning is a field of artificial intelligence." } ,
TopN : 2 ,
}
case constant . EndpointTypeOpenAIResponse :
// 返回 OpenAIResponsesRequest
return & dto . OpenAIResponsesRequest {
Model : model ,
Input : json . RawMessage ( "\"hi\"" ) ,
}
case constant . EndpointTypeAnthropic , constant . EndpointTypeGemini , constant . EndpointTypeOpenAI :
// 返回 GeneralOpenAIRequest
maxTokens := uint ( 10 )
if constant . EndpointType ( endpointType ) == constant . EndpointTypeGemini {
maxTokens = 3000
}
return & dto . GeneralOpenAIRequest {
Model : model ,
Stream : false ,
Messages : [ ] dto . Message {
{
Role : "user" ,
Content : "hi" ,
} ,
} ,
MaxTokens : maxTokens ,
}
}
}
// 自动检测逻辑(保持原有行为)
// 先判断是否为 Embedding 模型
if strings . Contains ( strings . ToLower ( model ) , "embedding" ) || // 其他 embedding 模型
strings . HasPrefix ( model , "m3e" ) || // m3e 系列模型
if strings . Contains ( strings . ToLower ( model ) , "embedding" ) ||
strings . HasPrefix ( model , "m3e" ) ||
strings . Contains ( model , "bge-" ) {
testRequest . Model = model
// Embedding 请求
testRequest . Input = [ ] any { "hello world" } // 修改为any, 因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
return testRequest
// 返回 EmbeddingRequest
return & dto . EmbeddingRequest {
Model : model ,
Input : [ ] any { "hello world" } ,
}
}
// 并非Embedding 模型
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
testRequest := & dto . GeneralOpenAIRequest {
Model : model ,
Stream : false ,
Messages : [ ] dto . Message {
{
Role : "user" ,
Content : "hi" ,
} ,
} ,
}
if strings . HasPrefix ( model , "o" ) {
testRequest . MaxCompletionTokens = 10
} else if strings . Contains ( model , "thinking" ) {
@@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest . MaxTokens = 10
}
testMessage := dto . Message {
Role : "user" ,
Content : "hi" ,
}
testRequest . Model = model
testRequest . Messages = append ( testRequest . Messages , testMessage )
return testRequest
}
@@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) {
// }
//}()
testModel := c . Query ( "model" )
endpointType := c . Query ( "endpoint_type" )
tik := time . Now ( )
result := testChannel ( channel , testModel )
result := testChannel ( channel , testModel , endpointType )
if result . localErr != nil {
c . JSON ( http . StatusOK , gin . H {
"success" : false ,
@@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) {
"message" : "" ,
"time" : consumedTime ,
} )
return
}
var testAllChannelsLock sync . Mutex
@@ -463,7 +577,7 @@ func testAllChannels(notify bool) error {
for _ , channel := range channels {
isChannelEnabled := channel . Status == common . ChannelStatusEnabled
tik := time . Now ( )
result := testChannel ( channel , "" )
result := testChannel ( channel , "" , "" )
tok := time . Now ( )
milliseconds := tok . Sub ( tik ) . Milliseconds ( )
@@ -477,7 +591,7 @@ func testAllChannels(notify bool) error {
// 当错误检查通过,才检查响应时间
if common . AutomaticDisableChannelEnabled && ! shouldBanChannel {
if milliseconds > disableThreshold {
err := errors . New ( fmt . Sprintf ( "响应时间 %.2fs 超过阈值 %.2fs" , float64 ( milliseconds ) / 1000.0 , float64 ( disableThreshold ) / 1000.0 ) )
err := fmt . Errorf ( "响应时间 %.2fs 超过阈值 %.2fs" , float64 ( milliseconds ) / 1000.0 , float64 ( disableThreshold ) / 1000.0 )
newAPIError = types . NewOpenAIError ( err , types . ErrorCodeChannelResponseTimeExceeded , http . StatusRequestTimeout )
shouldBanChannel = true
}
@@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
"success" : true ,
"message" : "" ,
} )
return
}
var autoTestChannelsOnce sync . Once