This commit refactors logging across the application by routing system-level log calls through the `common` package instead of calling the logger directly. Key changes:
- Replaced `logger.SysLog` and `logger.FatalLog` calls with `common.SysLog` and `common.FatalLog`, so system logging goes through one place.
- Updated error handling during resource initialization to use the new logging helpers, improving maintainability and readability.
- Made minor clarity and organization adjustments in several modules.
The goal is to streamline logging and simplify the overall architecture of the codebase.
422 lines
16 KiB
Go
422 lines
16 KiB
Go
package relay
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/logger"
|
||
"one-api/model"
|
||
relaycommon "one-api/relay/common"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/setting/model_setting"
|
||
"one-api/setting/operation_setting"
|
||
"one-api/types"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/shopspring/decimal"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||
|
||
info.InitChannelMeta(c)
|
||
|
||
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||
|
||
if !ok {
|
||
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
|
||
}
|
||
|
||
if textRequest.WebSearchOptions != nil {
|
||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||
}
|
||
|
||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
includeUsage := true
|
||
// 判断用户是否需要返回使用情况
|
||
if textRequest.StreamOptions != nil {
|
||
includeUsage = textRequest.StreamOptions.IncludeUsage
|
||
}
|
||
|
||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||
if !info.SupportStreamOptions || !textRequest.Stream {
|
||
textRequest.StreamOptions = nil
|
||
} else {
|
||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||
if constant.ForceStreamOption {
|
||
textRequest.StreamOptions = &dto.StreamOptions{
|
||
IncludeUsage: true,
|
||
}
|
||
}
|
||
}
|
||
|
||
info.ShouldIncludeUsage = includeUsage
|
||
|
||
adaptor := GetAdaptor(info.ApiType)
|
||
if adaptor == nil {
|
||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||
}
|
||
adaptor.Init(info)
|
||
var requestBody io.Reader
|
||
|
||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||
body, err := common.GetRequestBody(c)
|
||
if err != nil {
|
||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||
}
|
||
if common.DebugEnabled {
|
||
println("requestBody: ", string(body))
|
||
}
|
||
requestBody = bytes.NewBuffer(body)
|
||
} else {
|
||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
if info.ChannelSetting.SystemPrompt != "" {
|
||
// 如果有系统提示,则将其添加到请求中
|
||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||
containSystemPrompt := false
|
||
for _, message := range request.Messages {
|
||
if message.Role == request.GetSystemRoleName() {
|
||
containSystemPrompt = true
|
||
break
|
||
}
|
||
}
|
||
if !containSystemPrompt {
|
||
// 如果没有系统提示,则添加系统提示
|
||
systemMessage := dto.Message{
|
||
Role: request.GetSystemRoleName(),
|
||
Content: info.ChannelSetting.SystemPrompt,
|
||
}
|
||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||
} else if info.ChannelSetting.SystemPromptOverride {
|
||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||
for i, message := range request.Messages {
|
||
if message.Role == request.GetSystemRoleName() {
|
||
if message.IsStringContent() {
|
||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||
} else {
|
||
contents := message.ParseContent()
|
||
contents = append([]dto.MediaContent{
|
||
{
|
||
Type: dto.ContentTypeText,
|
||
Text: info.ChannelSetting.SystemPrompt,
|
||
},
|
||
}, contents...)
|
||
request.Messages[i].Content = contents
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
jsonData, err := common.Marshal(convertedRequest)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
|
||
// apply param override
|
||
if len(info.ParamOverride) > 0 {
|
||
reqMap := make(map[string]interface{})
|
||
_ = common.Unmarshal(jsonData, &reqMap)
|
||
for key, value := range info.ParamOverride {
|
||
reqMap[key] = value
|
||
}
|
||
jsonData, err = common.Marshal(reqMap)
|
||
if err != nil {
|
||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||
}
|
||
}
|
||
|
||
logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
|
||
|
||
requestBody = bytes.NewBuffer(jsonData)
|
||
}
|
||
|
||
var httpResp *http.Response
|
||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||
if err != nil {
|
||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||
}
|
||
|
||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||
|
||
if resp != nil {
|
||
httpResp = resp.(*http.Response)
|
||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
||
// reset status code 重置状态码
|
||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||
return newApiErr
|
||
}
|
||
}
|
||
|
||
usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
|
||
if newApiErr != nil {
|
||
// reset status code 重置状态码
|
||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||
return newApiErr
|
||
}
|
||
|
||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||
} else {
|
||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||
if usage == nil {
|
||
usage = &dto.Usage{
|
||
PromptTokens: relayInfo.PromptTokens,
|
||
CompletionTokens: 0,
|
||
TotalTokens: relayInfo.PromptTokens,
|
||
}
|
||
extraContent += "(可能是请求出错)"
|
||
}
|
||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||
promptTokens := usage.PromptTokens
|
||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||
completionTokens := usage.CompletionTokens
|
||
modelName := relayInfo.OriginModelName
|
||
|
||
tokenName := ctx.GetString("token_name")
|
||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||
imageRatio := relayInfo.PriceData.ImageRatio
|
||
modelRatio := relayInfo.PriceData.ModelRatio
|
||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||
modelPrice := relayInfo.PriceData.ModelPrice
|
||
|
||
// Convert values to decimal for precise calculation
|
||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||
|
||
ratio := dModelRatio.Mul(dGroupRatio)
|
||
|
||
// openai web search 工具计费
|
||
var dWebSearchQuota decimal.Decimal
|
||
var webSearchPrice float64
|
||
// response api 格式工具计费
|
||
if relayInfo.ResponsesUsageInfo != nil {
|
||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
|
||
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
|
||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
|
||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
|
||
}
|
||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||
// search-preview 模型不支持 response api
|
||
searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
|
||
if searchContextSize == "" {
|
||
searchContextSize = "medium"
|
||
}
|
||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||
searchContextSize, dWebSearchQuota.String())
|
||
}
|
||
// claude web search tool 计费
|
||
var dClaudeWebSearchQuota decimal.Decimal
|
||
var claudeWebSearchPrice float64
|
||
claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
|
||
if claudeWebSearchCallCount > 0 {
|
||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
|
||
}
|
||
// file search tool 计费
|
||
var dFileSearchQuota decimal.Decimal
|
||
var fileSearchPrice float64
|
||
if relayInfo.ResponsesUsageInfo != nil {
|
||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
|
||
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
|
||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||
}
|
||
}
|
||
|
||
var quotaCalculateDecimal decimal.Decimal
|
||
|
||
var audioInputQuota decimal.Decimal
|
||
var audioInputPrice float64
|
||
if !relayInfo.PriceData.UsePrice {
|
||
baseTokens := dPromptTokens
|
||
// 减去 cached tokens
|
||
var cachedTokensWithRatio decimal.Decimal
|
||
if !dCacheTokens.IsZero() {
|
||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||
}
|
||
|
||
// 减去 image tokens
|
||
var imageTokensWithRatio decimal.Decimal
|
||
if !dImageTokens.IsZero() {
|
||
baseTokens = baseTokens.Sub(dImageTokens)
|
||
imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
|
||
}
|
||
|
||
// 减去 Gemini audio tokens
|
||
if !dAudioTokens.IsZero() {
|
||
audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
|
||
if audioInputPrice > 0 {
|
||
// 重新计算 base tokens
|
||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||
}
|
||
}
|
||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
|
||
|
||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||
|
||
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
||
|
||
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
|
||
quotaCalculateDecimal = decimal.NewFromInt(1)
|
||
}
|
||
} else {
|
||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||
}
|
||
// 添加 responses tools call 调用的配额
|
||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||
// 添加 audio input 独立计费
|
||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||
|
||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||
totalTokens := promptTokens + completionTokens
|
||
|
||
var logContent string
|
||
if !relayInfo.PriceData.UsePrice {
|
||
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
|
||
} else {
|
||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||
}
|
||
|
||
// record all the consume log even if quota is 0
|
||
if totalTokens == 0 {
|
||
// in this case, must be some error happened
|
||
// we cannot just return, because we may have to return the pre-consumed quota
|
||
quota = 0
|
||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||
} else {
|
||
if !ratio.IsZero() && quota == 0 {
|
||
quota = 1
|
||
}
|
||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||
}
|
||
|
||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||
if quotaDelta != 0 {
|
||
err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||
if err != nil {
|
||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||
}
|
||
}
|
||
|
||
logModel := modelName
|
||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||
logModel = "gpt-4-gizmo-*"
|
||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||
}
|
||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||
logModel = "gpt-4o-gizmo-*"
|
||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||
}
|
||
if extraContent != "" {
|
||
logContent += ", " + extraContent
|
||
}
|
||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||
if imageTokens != 0 {
|
||
other["image"] = true
|
||
other["image_ratio"] = imageRatio
|
||
other["image_output"] = imageTokens
|
||
}
|
||
if !dWebSearchQuota.IsZero() {
|
||
if relayInfo.ResponsesUsageInfo != nil {
|
||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||
other["web_search"] = true
|
||
other["web_search_call_count"] = webSearchTool.CallCount
|
||
other["web_search_price"] = webSearchPrice
|
||
}
|
||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||
other["web_search"] = true
|
||
other["web_search_call_count"] = 1
|
||
other["web_search_price"] = webSearchPrice
|
||
}
|
||
} else if !dClaudeWebSearchQuota.IsZero() {
|
||
other["web_search"] = true
|
||
other["web_search_call_count"] = claudeWebSearchCallCount
|
||
other["web_search_price"] = claudeWebSearchPrice
|
||
}
|
||
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
|
||
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
|
||
other["file_search"] = true
|
||
other["file_search_call_count"] = fileSearchTool.CallCount
|
||
other["file_search_price"] = fileSearchPrice
|
||
}
|
||
}
|
||
if !audioInputQuota.IsZero() {
|
||
other["audio_input_seperate_price"] = true
|
||
other["audio_input_token_count"] = audioTokens
|
||
other["audio_input_price"] = audioInputPrice
|
||
}
|
||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||
ChannelId: relayInfo.ChannelId,
|
||
PromptTokens: promptTokens,
|
||
CompletionTokens: completionTokens,
|
||
ModelName: logModel,
|
||
TokenName: tokenName,
|
||
Quota: quota,
|
||
Content: logContent,
|
||
TokenId: relayInfo.TokenId,
|
||
UseTimeSeconds: int(useTimeSeconds),
|
||
IsStream: relayInfo.IsStream,
|
||
Group: relayInfo.UsingGroup,
|
||
Other: other,
|
||
})
|
||
}
|