This commit refactors the logging mechanism across the application, replacing direct logger calls with a centralized logging approach that uses the `common` package. Key changes:
- Replaced `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across various modules.
This change aims to streamline logging and improve the overall architecture of the codebase.
484 lines
13 KiB
Go
484 lines
13 KiB
Go
package controller
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"net/url"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/middleware"
|
||
"one-api/model"
|
||
"one-api/relay"
|
||
relaycommon "one-api/relay/common"
|
||
relayconstant "one-api/relay/constant"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/types"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/bytedance/gopkg/util/gopool"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
type testResult struct {
|
||
context *gin.Context
|
||
localErr error
|
||
newAPIError *types.NewAPIError
|
||
}
|
||
|
||
func testChannel(channel *model.Channel, testModel string) testResult {
|
||
tik := time.Now()
|
||
if channel.Type == constant.ChannelTypeMidjourney {
|
||
return testResult{
|
||
localErr: errors.New("midjourney channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||
return testResult{
|
||
localErr: errors.New("midjourney plus channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||
return testResult{
|
||
localErr: errors.New("suno channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
if channel.Type == constant.ChannelTypeKling {
|
||
return testResult{
|
||
localErr: errors.New("kling channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
if channel.Type == constant.ChannelTypeJimeng {
|
||
return testResult{
|
||
localErr: errors.New("jimeng channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
if channel.Type == constant.ChannelTypeVidu {
|
||
return testResult{
|
||
localErr: errors.New("vidu channel test is not supported"),
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
w := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(w)
|
||
|
||
requestPath := "/v1/chat/completions"
|
||
|
||
// 先判断是否为 Embedding 模型
|
||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||
strings.Contains(testModel, "embed") ||
|
||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||
requestPath = "/v1/embeddings" // 修改请求路径
|
||
}
|
||
|
||
c.Request = &http.Request{
|
||
Method: "POST",
|
||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||
Body: nil,
|
||
Header: make(http.Header),
|
||
}
|
||
|
||
if testModel == "" {
|
||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||
testModel = *channel.TestModel
|
||
} else {
|
||
if len(channel.GetModels()) > 0 {
|
||
testModel = channel.GetModels()[0]
|
||
} else {
|
||
testModel = "gpt-4o-mini"
|
||
}
|
||
}
|
||
}
|
||
|
||
cache, err := model.GetUserCache(1)
|
||
if err != nil {
|
||
return testResult{
|
||
localErr: err,
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
cache.WriteContext(c)
|
||
|
||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
c.Set("channel", channel.Type)
|
||
c.Set("base_url", channel.GetBaseURL())
|
||
group, _ := model.GetUserGroup(1, false)
|
||
c.Set("group", group)
|
||
|
||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||
if newAPIError != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: newAPIError,
|
||
newAPIError: newAPIError,
|
||
}
|
||
}
|
||
request := buildTestRequest(testModel)
|
||
|
||
info, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, request, nil)
|
||
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||
}
|
||
}
|
||
|
||
err = helper.ModelMappedHelper(c, info, nil)
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||
}
|
||
}
|
||
|
||
testModel = info.UpstreamModelName
|
||
request.Model = testModel
|
||
|
||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||
adaptor := relay.GetAdaptor(apiType)
|
||
if adaptor == nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||
}
|
||
}
|
||
|
||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
||
logInfo := *info
|
||
logInfo.ApiKey = ""
|
||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||
|
||
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||
}
|
||
}
|
||
|
||
adaptor.Init(info)
|
||
|
||
var convertedRequest any
|
||
// 根据 RelayMode 选择正确的转换函数
|
||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||
// 创建一个 EmbeddingRequest
|
||
embeddingRequest := dto.EmbeddingRequest{
|
||
Input: request.Input,
|
||
Model: request.Model,
|
||
}
|
||
// 调用专门用于 Embedding 的转换函数
|
||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||
} else {
|
||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||
}
|
||
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||
}
|
||
}
|
||
jsonData, err := json.Marshal(convertedRequest)
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||
}
|
||
}
|
||
requestBody := bytes.NewBuffer(jsonData)
|
||
c.Request.Body = io.NopCloser(requestBody)
|
||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||
}
|
||
}
|
||
var httpResp *http.Response
|
||
if resp != nil {
|
||
httpResp = resp.(*http.Response)
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
err := service.RelayErrorHandler(httpResp, true)
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||
}
|
||
}
|
||
}
|
||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||
if respErr != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: respErr,
|
||
newAPIError: respErr,
|
||
}
|
||
}
|
||
if usageA == nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: errors.New("usage is nil"),
|
||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||
}
|
||
}
|
||
usage := usageA.(*dto.Usage)
|
||
result := w.Result()
|
||
respBody, err := io.ReadAll(result.Body)
|
||
if err != nil {
|
||
return testResult{
|
||
context: c,
|
||
localErr: err,
|
||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||
}
|
||
}
|
||
info.PromptTokens = usage.PromptTokens
|
||
|
||
quota := 0
|
||
if !priceData.UsePrice {
|
||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||
if priceData.ModelRatio != 0 && quota <= 0 {
|
||
quota = 1
|
||
}
|
||
} else {
|
||
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
||
}
|
||
tok := time.Now()
|
||
milliseconds := tok.Sub(tik).Milliseconds()
|
||
consumedTime := float64(milliseconds) / 1000.0
|
||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||
ChannelId: channel.Id,
|
||
PromptTokens: usage.PromptTokens,
|
||
CompletionTokens: usage.CompletionTokens,
|
||
ModelName: info.OriginModelName,
|
||
TokenName: "模型测试",
|
||
Quota: quota,
|
||
Content: "模型测试",
|
||
UseTimeSeconds: int(consumedTime),
|
||
IsStream: info.IsStream,
|
||
Group: info.UsingGroup,
|
||
Other: other,
|
||
})
|
||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||
return testResult{
|
||
context: c,
|
||
localErr: nil,
|
||
newAPIError: nil,
|
||
}
|
||
}
|
||
|
||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||
testRequest := &dto.GeneralOpenAIRequest{
|
||
Model: "", // this will be set later
|
||
Stream: false,
|
||
}
|
||
|
||
// 先判断是否为 Embedding 模型
|
||
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
|
||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
||
strings.Contains(model, "bge-") {
|
||
testRequest.Model = model
|
||
// Embedding 请求
|
||
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||
return testRequest
|
||
}
|
||
// 并非Embedding 模型
|
||
if strings.HasPrefix(model, "o") {
|
||
testRequest.MaxCompletionTokens = 10
|
||
} else if strings.Contains(model, "thinking") {
|
||
if !strings.Contains(model, "claude") {
|
||
testRequest.MaxTokens = 50
|
||
}
|
||
} else if strings.Contains(model, "gemini") {
|
||
testRequest.MaxTokens = 3000
|
||
} else {
|
||
testRequest.MaxTokens = 10
|
||
}
|
||
|
||
testMessage := dto.Message{
|
||
Role: "user",
|
||
Content: "hi",
|
||
}
|
||
testRequest.Model = model
|
||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||
return testRequest
|
||
}
|
||
|
||
func TestChannel(c *gin.Context) {
|
||
channelId, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
channel, err := model.CacheGetChannel(channelId)
|
||
if err != nil {
|
||
channel, err = model.GetChannelById(channelId, true)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
}
|
||
//defer func() {
|
||
// if channel.ChannelInfo.IsMultiKey {
|
||
// go func() { _ = channel.SaveChannelInfo() }()
|
||
// }
|
||
//}()
|
||
testModel := c.Query("model")
|
||
tik := time.Now()
|
||
result := testChannel(channel, testModel)
|
||
if result.localErr != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": result.localErr.Error(),
|
||
"time": 0.0,
|
||
})
|
||
return
|
||
}
|
||
tok := time.Now()
|
||
milliseconds := tok.Sub(tik).Milliseconds()
|
||
go channel.UpdateResponseTime(milliseconds)
|
||
consumedTime := float64(milliseconds) / 1000.0
|
||
if result.newAPIError != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": result.newAPIError.Error(),
|
||
"time": consumedTime,
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"time": consumedTime,
|
||
})
|
||
return
|
||
}
|
||
|
||
var (
	// testAllChannelsLock guards testAllChannelsRunning.
	testAllChannelsLock sync.Mutex
	// testAllChannelsRunning is true while a testAllChannels run is in
	// progress, so that runs never overlap.
	testAllChannelsRunning bool
)
|
||
|
||
func testAllChannels(notify bool) error {
|
||
|
||
testAllChannelsLock.Lock()
|
||
if testAllChannelsRunning {
|
||
testAllChannelsLock.Unlock()
|
||
return errors.New("测试已在运行中")
|
||
}
|
||
testAllChannelsRunning = true
|
||
testAllChannelsLock.Unlock()
|
||
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||
if getChannelErr != nil {
|
||
return getChannelErr
|
||
}
|
||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||
if disableThreshold == 0 {
|
||
disableThreshold = 10000000 // a impossible value
|
||
}
|
||
gopool.Go(func() {
|
||
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
||
defer func() {
|
||
testAllChannelsLock.Lock()
|
||
testAllChannelsRunning = false
|
||
testAllChannelsLock.Unlock()
|
||
}()
|
||
|
||
for _, channel := range channels {
|
||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||
tik := time.Now()
|
||
result := testChannel(channel, "")
|
||
tok := time.Now()
|
||
milliseconds := tok.Sub(tik).Milliseconds()
|
||
|
||
shouldBanChannel := false
|
||
newAPIError := result.newAPIError
|
||
// request error disables the channel
|
||
if newAPIError != nil {
|
||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||
}
|
||
|
||
// 当错误检查通过,才检查响应时间
|
||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||
if milliseconds > disableThreshold {
|
||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||
shouldBanChannel = true
|
||
}
|
||
}
|
||
|
||
// disable channel
|
||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||
}
|
||
|
||
// enable channel
|
||
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||
}
|
||
|
||
channel.UpdateResponseTime(milliseconds)
|
||
time.Sleep(common.RequestInterval)
|
||
}
|
||
|
||
if notify {
|
||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||
}
|
||
})
|
||
return nil
|
||
}
|
||
|
||
func TestAllChannels(c *gin.Context) {
|
||
err := testAllChannels(true)
|
||
if err != nil {
|
||
common.ApiError(c, err)
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func AutomaticallyTestChannels(frequency int) {
|
||
if frequency <= 0 {
|
||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||
return
|
||
}
|
||
for {
|
||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||
common.SysLog("testing all channels")
|
||
_ = testAllChannels(false)
|
||
common.SysLog("channel test finished")
|
||
}
|
||
}
|