This commit refactors the application's logging by replacing direct logger calls with a centralized approach based on the `common` package. Key changes:
- Replaced `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across various modules.
This change aims to streamline logging and improve the overall architecture of the codebase.
464 lines
14 KiB
Go
464 lines
14 KiB
Go
package controller
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/logger"
|
||
"one-api/middleware"
|
||
"one-api/model"
|
||
"one-api/relay"
|
||
relaycommon "one-api/relay/common"
|
||
relayconstant "one-api/relay/constant"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/setting"
|
||
"one-api/types"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||
var err *types.NewAPIError
|
||
switch info.RelayMode {
|
||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||
err = relay.ImageHelper(c, info)
|
||
case relayconstant.RelayModeAudioSpeech:
|
||
fallthrough
|
||
case relayconstant.RelayModeAudioTranslation:
|
||
fallthrough
|
||
case relayconstant.RelayModeAudioTranscription:
|
||
err = relay.AudioHelper(c, info)
|
||
case relayconstant.RelayModeRerank:
|
||
err = relay.RerankHelper(c, info)
|
||
case relayconstant.RelayModeEmbeddings:
|
||
err = relay.EmbeddingHelper(c, info)
|
||
case relayconstant.RelayModeResponses:
|
||
err = relay.ResponsesHelper(c, info)
|
||
default:
|
||
err = relay.TextHelper(c, info)
|
||
}
|
||
return err
|
||
}
|
||
|
||
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||
var err *types.NewAPIError
|
||
if strings.Contains(c.Request.URL.Path, "embed") {
|
||
err = relay.GeminiEmbeddingHandler(c, info)
|
||
} else {
|
||
err = relay.GeminiHelper(c, info)
|
||
}
|
||
return err
|
||
}
|
||
|
||
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||
|
||
requestId := c.GetString(common.RequestIdKey)
|
||
group := c.GetString("group")
|
||
originalModel := c.GetString("original_model")
|
||
|
||
var (
|
||
newAPIError *types.NewAPIError
|
||
ws *websocket.Conn
|
||
)
|
||
|
||
if relayFormat == types.RelayFormatOpenAIRealtime {
|
||
var err error
|
||
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
||
if err != nil {
|
||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||
return
|
||
}
|
||
defer ws.Close()
|
||
}
|
||
|
||
defer func() {
|
||
if newAPIError != nil {
|
||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||
switch relayFormat {
|
||
case types.RelayFormatOpenAIRealtime:
|
||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||
case types.RelayFormatClaude:
|
||
c.JSON(newAPIError.StatusCode, gin.H{
|
||
"type": "error",
|
||
"error": newAPIError.ToClaudeError(),
|
||
})
|
||
default:
|
||
c.JSON(newAPIError.StatusCode, gin.H{
|
||
"error": newAPIError.ToOpenAIError(),
|
||
})
|
||
}
|
||
}
|
||
}()
|
||
|
||
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||
if err != nil {
|
||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||
return
|
||
}
|
||
|
||
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||
if err != nil {
|
||
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||
return
|
||
}
|
||
|
||
meta := request.GetTokenCountMeta()
|
||
|
||
if setting.ShouldCheckPromptSensitive() {
|
||
words, err := service.CheckSensitiveText(meta.CombineText)
|
||
if err != nil {
|
||
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||
return
|
||
}
|
||
}
|
||
|
||
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||
if err != nil {
|
||
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||
return
|
||
}
|
||
|
||
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||
if err != nil {
|
||
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||
return
|
||
}
|
||
|
||
preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||
if newApiErr != nil {
|
||
return
|
||
}
|
||
|
||
defer func() {
|
||
if newApiErr != nil {
|
||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||
}
|
||
}()
|
||
|
||
for i := 0; i <= common.RetryTimes; i++ {
|
||
channel, err := getChannel(c, group, originalModel, i)
|
||
if err != nil {
|
||
logger.LogError(c, err.Error())
|
||
newAPIError = err
|
||
break
|
||
}
|
||
|
||
addUsedChannel(c, channel.Id)
|
||
requestBody, _ := common.GetRequestBody(c)
|
||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||
|
||
switch relayFormat {
|
||
case types.RelayFormatOpenAIRealtime:
|
||
newAPIError = relay.WssHelper(c, relayInfo)
|
||
case types.RelayFormatClaude:
|
||
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||
case types.RelayFormatGemini:
|
||
newAPIError = geminiRelayHandler(c, relayInfo)
|
||
default:
|
||
newAPIError = relayHandler(c, relayInfo)
|
||
}
|
||
|
||
if newAPIError == nil {
|
||
return
|
||
} else {
|
||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
|
||
// 保存错误日志到mysql中
|
||
userId := c.GetInt("id")
|
||
tokenName := c.GetString("token_name")
|
||
modelName := c.GetString("original_model")
|
||
tokenId := c.GetInt("token_id")
|
||
userGroup := c.GetString("group")
|
||
channelId := c.GetInt("channel_id")
|
||
other := make(map[string]interface{})
|
||
other["error_type"] = newAPIError.GetErrorType()
|
||
other["error_code"] = newAPIError.GetErrorCode()
|
||
other["status_code"] = newAPIError.StatusCode
|
||
other["channel_id"] = channelId
|
||
other["channel_name"] = c.GetString("channel_name")
|
||
other["channel_type"] = c.GetInt("channel_type")
|
||
adminInfo := make(map[string]interface{})
|
||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||
if isMultiKey {
|
||
adminInfo["is_multi_key"] = true
|
||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||
}
|
||
other["admin_info"] = adminInfo
|
||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||
}
|
||
}
|
||
|
||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||
|
||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||
break
|
||
}
|
||
}
|
||
|
||
useChannel := c.GetStringSlice("use_channel")
|
||
if len(useChannel) > 1 {
|
||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||
logger.LogInfo(c, retryLogStr)
|
||
}
|
||
}
|
||
|
||
var upgrader = websocket.Upgrader{
|
||
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true // 允许跨域
|
||
},
|
||
}
|
||
|
||
func addUsedChannel(c *gin.Context, channelId int) {
|
||
useChannel := c.GetStringSlice("use_channel")
|
||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||
c.Set("use_channel", useChannel)
|
||
}
|
||
|
||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||
if retryCount == 0 {
|
||
autoBan := c.GetBool("auto_ban")
|
||
autoBanInt := 1
|
||
if !autoBan {
|
||
autoBanInt = 0
|
||
}
|
||
return &model.Channel{
|
||
Id: c.GetInt("channel_id"),
|
||
Type: c.GetInt("channel_type"),
|
||
Name: c.GetString("channel_name"),
|
||
AutoBan: &autoBanInt,
|
||
}, nil
|
||
}
|
||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||
if err != nil {
|
||
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
if channel == nil {
|
||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||
if newAPIError != nil {
|
||
return nil, newAPIError
|
||
}
|
||
return channel, nil
|
||
}
|
||
|
||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||
if openaiErr == nil {
|
||
return false
|
||
}
|
||
if types.IsChannelError(openaiErr) {
|
||
return true
|
||
}
|
||
if types.IsSkipRetryError(openaiErr) {
|
||
return false
|
||
}
|
||
if retryTimes <= 0 {
|
||
return false
|
||
}
|
||
if _, ok := c.Get("specific_channel_id"); ok {
|
||
return false
|
||
}
|
||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||
return true
|
||
}
|
||
if openaiErr.StatusCode == 307 {
|
||
return true
|
||
}
|
||
if openaiErr.StatusCode/100 == 5 {
|
||
// 超时不重试
|
||
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||
return false
|
||
}
|
||
if openaiErr.StatusCode == 408 {
|
||
// azure处理超时不重试
|
||
return false
|
||
}
|
||
if openaiErr.StatusCode/100 == 2 {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||
service.DisableChannel(channelError, err.Error())
|
||
}
|
||
}
|
||
|
||
func RelayMidjourney(c *gin.Context) {
|
||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||
"type": "upstream_error",
|
||
"code": 4,
|
||
})
|
||
return
|
||
}
|
||
|
||
var mjErr *dto.MidjourneyResponse
|
||
switch relayInfo.RelayMode {
|
||
case relayconstant.RelayModeMidjourneyNotify:
|
||
mjErr = relay.RelayMidjourneyNotify(c)
|
||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||
case relayconstant.RelayModeSwapFace:
|
||
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||
default:
|
||
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||
}
|
||
//err = relayMidjourneySubmit(c, relayMode)
|
||
log.Println(mjErr)
|
||
if mjErr != nil {
|
||
statusCode := http.StatusBadRequest
|
||
if mjErr.Code == 30 {
|
||
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||
statusCode = http.StatusTooManyRequests
|
||
}
|
||
c.JSON(statusCode, gin.H{
|
||
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||
"type": "upstream_error",
|
||
"code": mjErr.Code,
|
||
})
|
||
channelId := c.GetInt("channel_id")
|
||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||
}
|
||
}
|
||
|
||
func RelayNotImplemented(c *gin.Context) {
|
||
err := dto.OpenAIError{
|
||
Message: "API not implemented",
|
||
Type: "new_api_error",
|
||
Param: "",
|
||
Code: "api_not_implemented",
|
||
}
|
||
c.JSON(http.StatusNotImplemented, gin.H{
|
||
"error": err,
|
||
})
|
||
}
|
||
|
||
func RelayNotFound(c *gin.Context) {
|
||
err := dto.OpenAIError{
|
||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||
Type: "invalid_request_error",
|
||
Param: "",
|
||
Code: "",
|
||
}
|
||
c.JSON(http.StatusNotFound, gin.H{
|
||
"error": err,
|
||
})
|
||
}
|
||
|
||
func RelayTask(c *gin.Context) {
|
||
retryTimes := common.RetryTimes
|
||
channelId := c.GetInt("channel_id")
|
||
relayMode := c.GetInt("relay_mode")
|
||
group := c.GetString("group")
|
||
originalModel := c.GetString("original_model")
|
||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||
taskErr := taskRelayHandler(c, relayMode)
|
||
if taskErr == nil {
|
||
retryTimes = 0
|
||
}
|
||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||
if newAPIError != nil {
|
||
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||
break
|
||
}
|
||
channelId = channel.Id
|
||
useChannel := c.GetStringSlice("use_channel")
|
||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||
c.Set("use_channel", useChannel)
|
||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||
|
||
requestBody, _ := common.GetRequestBody(c)
|
||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||
taskErr = taskRelayHandler(c, relayMode)
|
||
}
|
||
useChannel := c.GetStringSlice("use_channel")
|
||
if len(useChannel) > 1 {
|
||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||
logger.LogInfo(c, retryLogStr)
|
||
}
|
||
if taskErr != nil {
|
||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||
}
|
||
c.JSON(taskErr.StatusCode, taskErr)
|
||
}
|
||
}
|
||
|
||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||
var err *dto.TaskError
|
||
switch relayMode {
|
||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||
err = relay.RelayTaskFetch(c, relayMode)
|
||
default:
|
||
err = relay.RelayTaskSubmit(c, relayMode)
|
||
}
|
||
return err
|
||
}
|
||
|
||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
||
if taskErr == nil {
|
||
return false
|
||
}
|
||
if retryTimes <= 0 {
|
||
return false
|
||
}
|
||
if _, ok := c.Get("specific_channel_id"); ok {
|
||
return false
|
||
}
|
||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||
return true
|
||
}
|
||
if taskErr.StatusCode == 307 {
|
||
return true
|
||
}
|
||
if taskErr.StatusCode/100 == 5 {
|
||
// 超时不重试
|
||
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
if taskErr.StatusCode == http.StatusBadRequest {
|
||
return false
|
||
}
|
||
if taskErr.StatusCode == 408 {
|
||
// azure处理超时不重试
|
||
return false
|
||
}
|
||
if taskErr.LocalError {
|
||
return false
|
||
}
|
||
if taskErr.StatusCode/100 == 2 {
|
||
return false
|
||
}
|
||
return true
|
||
}
|