Files
new-api/controller/channel-test.go
Calcium-Ion 2911b9cd04 Merge pull request #1368 from RedwindA/fix/embedding-test
fix: 修复Gemini渠道的向量模型测试
2025-07-17 19:16:10 +08:00

465 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strconv"
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
// testResult is the outcome of a single channel test run by testChannel.
type testResult struct {
	// context is the synthetic gin context the test request ran in. It is nil
	// when the test was rejected before a result carrying the context was
	// produced (unsupported channel type, user cache lookup failure).
	context *gin.Context

	// localErr is non-nil whenever the test did not complete successfully,
	// whether the failure happened locally or upstream.
	localErr error

	// newAPIError is the relay-level error, when one was produced. It is the
	// value the callers hand to the auto-disable / auto-enable checks.
	newAPIError *types.NewAPIError
}
// testChannel sends one minimal, non-streaming request through the relay
// adaptor of the given channel and reports how it went.
//
// testModel selects the model to probe. When it is empty the channel's
// configured TestModel is used, then the first model of the channel, then
// "gpt-4o-mini". Models that look like embedding models (and every MokaAI
// channel) are probed on /v1/embeddings, everything else on
// /v1/chat/completions.
//
// The request runs inside a synthetic gin context backed by an
// httptest.ResponseRecorder and is attributed to user ID 1. On success a
// consume log entry is recorded for that user.
//
// The returned testResult has localErr == nil only when the whole round trip
// succeeded. newAPIError is additionally set when the failure came out of the
// relay pipeline.
func testChannel(channel *model.Channel, testModel string) testResult {
	startedAt := time.Now()

	// These channel types are rejected up front with a plain local error.
	switch channel.Type {
	case constant.ChannelTypeMidjourney:
		return testResult{localErr: errors.New("midjourney channel test is not supported")}
	case constant.ChannelTypeMidjourneyPlus:
		return testResult{localErr: errors.New("midjourney plus channel test is not supported")}
	case constant.ChannelTypeSunoAPI:
		return testResult{localErr: errors.New("suno channel test is not supported")}
	case constant.ChannelTypeKling:
		return testResult{localErr: errors.New("kling channel test is not supported")}
	case constant.ChannelTypeJimeng:
		return testResult{localErr: errors.New("jimeng channel test is not supported")}
	}

	// Resolve the model BEFORE choosing the endpoint. Otherwise a channel whose
	// default test model is an embedding model would be probed on the chat
	// endpoint whenever the caller passes an empty model name.
	if testModel == "" {
		if channel.TestModel != nil && *channel.TestModel != "" {
			testModel = *channel.TestModel
		} else if models := channel.GetModels(); len(models) > 0 {
			testModel = models[0]
		} else {
			testModel = "gpt-4o-mini"
		}
	}

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	// Pick the endpoint: embedding-like model names and MokaAI channels are
	// tested against the embeddings API.
	requestPath := "/v1/chat/completions"
	if strings.Contains(strings.ToLower(testModel), "embedding") ||
		strings.HasPrefix(testModel, "m3e") || // m3e model family
		strings.Contains(testModel, "bge-") || // bge model family
		strings.Contains(testModel, "embed") ||
		channel.Type == constant.ChannelTypeMokaAI {
		requestPath = "/v1/embeddings"
	}

	c.Request = &http.Request{
		Method: "POST",
		URL:    &url.URL{Path: requestPath},
		Body:   nil,
		Header: make(http.Header),
	}

	cache, err := model.GetUserCache(1)
	if err != nil {
		return testResult{localErr: err}
	}
	cache.WriteContext(c)

	//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
	c.Request.Header.Set("Content-Type", "application/json")
	c.Set("channel", channel.Type)
	c.Set("base_url", channel.GetBaseURL())
	// The group lookup error is deliberately ignored, as before: whatever
	// value comes back is stored in the context.
	group, _ := model.GetUserGroup(1, false)
	c.Set("group", group)

	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
	if newAPIError != nil {
		return testResult{
			context:     c,
			localErr:    newAPIError,
			newAPIError: newAPIError,
		}
	}

	info := relaycommon.GenRelayInfo(c)
	if err = helper.ModelMappedHelper(c, info, nil); err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
		}
	}
	// From here on the upstream (post-mapping) model name is used.
	testModel = info.UpstreamModelName

	apiType, _ := common.ChannelType2APIType(channel.Type)
	adaptor := relay.GetAdaptor(apiType)
	if adaptor == nil {
		adaptorErr := fmt.Errorf("invalid api type: %d, adaptor is nil", apiType)
		return testResult{
			context:     c,
			localErr:    adaptorErr,
			newAPIError: types.NewError(adaptorErr, types.ErrorCodeInvalidApiType),
		}
	}

	request := buildTestRequest(testModel)

	// Log a copy of the relay info with the API key blanked out.
	logInfo := *info
	logInfo.ApiKey = ""
	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))

	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
		}
	}

	adaptor.Init(info)

	// Convert the generic test request with the converter matching the relay mode.
	var convertedRequest any
	if info.RelayMode == relayconstant.RelayModeEmbeddings {
		// The endpoint was chosen from the pre-mapping model name and the
		// channel type, while buildTestRequest only saw the upstream name, so
		// the request may have been built as a chat request without any
		// input. Fall back to a fixed input instead of sending none.
		input := request.Input
		if input == nil {
			input = []any{"hello world"}
		}
		embeddingRequest := dto.EmbeddingRequest{
			Input: input,
			Model: request.Model,
		}
		convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
	} else {
		// Every other relay mode (e.g. chat) keeps the OpenAI-style conversion.
		convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
	}
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
		}
	}

	jsonData, err := json.Marshal(convertedRequest)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
		}
	}

	requestBody := bytes.NewBuffer(jsonData)
	c.Request.Body = io.NopCloser(requestBody)
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
		}
	}

	var httpResp *http.Response
	if resp != nil {
		// Checked assertion: an unexpected response type becomes a test
		// failure instead of a panic.
		typedResp, ok := resp.(*http.Response)
		if !ok {
			typeErr := fmt.Errorf("unexpected upstream response type %T", resp)
			return testResult{
				context:     c,
				localErr:    typeErr,
				newAPIError: types.NewError(typeErr, types.ErrorCodeBadResponse),
			}
		}
		httpResp = typedResp
		if httpResp.StatusCode != http.StatusOK {
			relayErr := service.RelayErrorHandler(httpResp, true)
			return testResult{
				context:     c,
				localErr:    relayErr,
				newAPIError: types.NewError(relayErr, types.ErrorCodeBadResponse),
			}
		}
	}

	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
	if respErr != nil {
		return testResult{
			context:     c,
			localErr:    respErr,
			newAPIError: respErr,
		}
	}
	if usageA == nil {
		usageErr := errors.New("usage is nil")
		return testResult{
			context:     c,
			localErr:    usageErr,
			newAPIError: types.NewError(usageErr, types.ErrorCodeBadResponseBody),
		}
	}
	usage, ok := usageA.(*dto.Usage)
	if !ok {
		usageErr := fmt.Errorf("unexpected usage type %T", usageA)
		return testResult{
			context:     c,
			localErr:    usageErr,
			newAPIError: types.NewError(usageErr, types.ErrorCodeBadResponseBody),
		}
	}

	// Whatever the adaptor wrote to the gin context ended up in the recorder.
	result := w.Result()
	respBody, err := io.ReadAll(result.Body)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
		}
	}

	info.PromptTokens = usage.PromptTokens

	// Quota is either a fixed per-call price or derived from token counts.
	var quota int
	if priceData.UsePrice {
		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
	} else {
		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
		// A model with a non-zero ratio is never logged with a zero quota.
		if priceData.ModelRatio != 0 && quota <= 0 {
			quota = 1
		}
	}

	elapsedSeconds := int(time.Since(startedAt).Seconds())
	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
		ChannelId:        channel.Id,
		PromptTokens:     usage.PromptTokens,
		CompletionTokens: usage.CompletionTokens,
		ModelName:        info.OriginModelName,
		TokenName:        "模型测试",
		Quota:            quota,
		Content:          "模型测试",
		UseTimeSeconds:   elapsedSeconds,
		IsStream:         false,
		Group:            info.UsingGroup,
		Other:            other,
	})
	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
	return testResult{context: c}
}
// buildTestRequest returns the minimal non-streaming request used to probe a
// channel with the given model.
//
// Embedding-like model names produce an embeddings payload (Input set, no
// messages). Every other model produces a one-message chat payload ("hi")
// with a small output-token budget that depends on the model name.
//
// The embedding detection uses the same name checks as the endpoint selection
// in testChannel, including the plain "embed" substring, so a model routed to
// the embeddings endpoint by name also gets an embeddings payload.
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
	testRequest := &dto.GeneralOpenAIRequest{
		Model:  model,
		Stream: false,
	}

	// Embedding models: send an input and nothing else.
	if strings.Contains(strings.ToLower(model), "embedding") ||
		strings.HasPrefix(model, "m3e") || // m3e model family
		strings.Contains(model, "bge-") || // bge model family
		strings.Contains(model, "embed") {
		// []any rather than []string: per the original note, ParseInput in
		// dto/openai_request.go cannot handle a []string.
		testRequest.Input = []any{"hello world"}
		return testRequest
	}

	// Chat models: choose the token limit field and budget by model name.
	switch {
	case strings.HasPrefix(model, "o"):
		// NOTE(review): this matches every model whose name starts with "o",
		// presumably aimed at the OpenAI o-series — confirm.
		testRequest.MaxCompletionTokens = 10
	case strings.Contains(model, "thinking"):
		// Claude thinking models are sent without an explicit limit.
		if !strings.Contains(model, "claude") {
			testRequest.MaxTokens = 50
		}
	case strings.Contains(model, "gemini"):
		testRequest.MaxTokens = 3000
	default:
		testRequest.MaxTokens = 10
	}

	testRequest.Messages = append(testRequest.Messages, dto.Message{
		Role:    "user",
		Content: "hi",
	})
	return testRequest
}
// TestChannel is the HTTP handler that tests a single channel.
//
// The channel is identified by the "id" path parameter; the optional "model"
// query parameter selects the model to test. The handler always answers with
// HTTP 200 and a JSON body holding "success", "message" and "time" (seconds).
// When the test fails with a local error the reported time is 0 and the
// channel's response time is not updated.
func TestChannel(c *gin.Context) {
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel, err := model.CacheGetChannel(channelId)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	//defer func() {
	//	if channel.ChannelInfo.IsMultiKey {
	//		go func() { _ = channel.SaveChannelInfo() }()
	//	}
	//}()

	// reply writes the uniform JSON answer of this endpoint.
	reply := func(success bool, message string, seconds float64) {
		c.JSON(http.StatusOK, gin.H{
			"success": success,
			"message": message,
			"time":    seconds,
		})
	}

	testModel := c.Query("model")
	startedAt := time.Now()
	result := testChannel(channel, testModel)
	if result.localErr != nil {
		reply(false, result.localErr.Error(), 0.0)
		return
	}

	elapsedMs := time.Since(startedAt).Milliseconds()
	// Persist the measured latency without delaying the HTTP answer.
	go channel.UpdateResponseTime(elapsedMs)
	elapsedSeconds := float64(elapsedMs) / 1000.0

	if result.newAPIError != nil {
		reply(false, result.newAPIError.Error(), elapsedSeconds)
		return
	}
	reply(true, "", elapsedSeconds)
}
// State shared by all invocations of testAllChannels.
var (
	// testAllChannelsLock guards testAllChannelsRunning.
	testAllChannelsLock sync.Mutex
	// testAllChannelsRunning is true while a full channel test is in progress.
	testAllChannelsRunning bool
)
// testAllChannels starts a background run that tests every channel returned
// by model.GetAllChannels, one after another.
//
// Only one run may be active at a time: when a run is already in progress an
// error is returned immediately (its message means "a test is already
// running"). An error is also returned when the channel list cannot be
// loaded. In every other case the function returns nil right after scheduling
// the run; the tests themselves execute asynchronously.
//
// For each channel the run may auto-disable it (test error that qualifies for
// disabling, or response time above common.ChannelDisableThreshold when
// automatic disabling is on) or re-enable it, and it always stores the
// measured response time. When notify is true the root user is notified once
// all channels have been processed.
func testAllChannels(notify bool) error {
	testAllChannelsLock.Lock()
	if testAllChannelsRunning {
		testAllChannelsLock.Unlock()
		return errors.New("测试已在运行中")
	}
	testAllChannelsRunning = true
	testAllChannelsLock.Unlock()

	// markFinished clears the running flag so that a later run can start.
	markFinished := func() {
		testAllChannelsLock.Lock()
		testAllChannelsRunning = false
		testAllChannelsLock.Unlock()
	}

	channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
	if getChannelErr != nil {
		// No background run is started on this path, so the flag has to be
		// released here; otherwise every later call would be rejected as
		// "already running".
		markFinished()
		return getChannelErr
	}

	disableThreshold := int64(common.ChannelDisableThreshold * 1000)
	if disableThreshold == 0 {
		disableThreshold = 10000000 // effectively "never": an unreachable latency
	}

	gopool.Go(func() {
		// Always reset the running state, whatever happens in the loop.
		defer markFinished()

		for _, channel := range channels {
			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
			startedAt := time.Now()
			result := testChannel(channel, "")
			milliseconds := time.Since(startedAt).Milliseconds()

			// channelKey reads the key recorded in the test context. The
			// context is nil when the test was rejected early, so guard it.
			// NOTE(review): assumes GetContextKeyString cannot take a nil
			// context — confirm.
			channelKey := func() string {
				if result.context == nil {
					return ""
				}
				return common.GetContextKeyString(result.context, constant.ContextKeyChannelKey)
			}

			shouldBanChannel := false
			newAPIError := result.newAPIError
			// A request error may disable the channel.
			if newAPIError != nil {
				shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
			}

			// The response time is only checked when the error check did not
			// already decide to ban the channel.
			if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
				if milliseconds > disableThreshold {
					err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
					newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
					shouldBanChannel = true
				}
			}

			// Disable the channel.
			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
				go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, channelKey(), channel.GetAutoBan()), newAPIError)
			}

			// Re-enable the channel, but only when the test actually ran to
			// completion: a test that was rejected locally (localErr set,
			// no relay error) proves nothing about the channel's health.
			if !isChannelEnabled && result.localErr == nil && service.ShouldEnableChannel(newAPIError, channel.Status) {
				service.EnableChannel(channel.Id, channelKey(), channel.Name)
			}

			channel.UpdateResponseTime(milliseconds)
			// Pause between channels.
			time.Sleep(common.RequestInterval)
		}

		if notify {
			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
		}
	})
	return nil
}
// TestAllChannels is the HTTP handler that kicks off a test of all channels
// with root-user notification enabled. It reports an API error when the run
// could not be started and a JSON success body otherwise; it does not wait
// for the tests to complete.
func TestAllChannels(c *gin.Context) {
	if startErr := testAllChannels(true); startErr != nil {
		common.ApiError(c, startErr)
		return
	}
	body := gin.H{
		"success": true,
		"message": "",
	}
	c.JSON(http.StatusOK, body)
}
// AutomaticallyTestChannels triggers a test of all channels every `frequency`
// minutes, without notifying the root user.
//
// A non-positive frequency disables the feature: a message is logged and the
// function returns. Otherwise the function loops forever and never returns,
// so it is meant to run in its own goroutine.
func AutomaticallyTestChannels(frequency int) {
	if frequency <= 0 {
		common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
		return
	}
	interval := time.Duration(frequency) * time.Minute
	for {
		time.Sleep(interval)
		common.SysLog("testing all channels")
		// testAllChannels only schedules the run and returns at once, so the
		// log reports whether the run was started, not that it finished.
		if err := testAllChannels(false); err != nil {
			common.SysLog("automatic channel test not started: " + err.Error())
			continue
		}
		common.SysLog("channel test started in background")
	}
}