These changes ensure SSE ping packets are sent while waiting for a response from the upstream. The previous implementation did not send ping packets until after the upstream had responded, rendering the keep-alive feature ineffective.
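For reference, here is a minimal sketch of the idea, assuming a gin handler whose SSE response headers have already been written. The function name, signature, and ping interval below are illustrative only; the actual logic lives in helper.DoStreamRequestWithPinger, which has a different signature and decides internally whether pinging is enabled.

package relay

import (
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)

// doStreamRequestWithPinger is a hypothetical, simplified version of the
// helper, showing only the ordering the fix establishes: the pinger starts
// BEFORE the upstream call, not after it returns.
func doStreamRequestWithPinger(c *gin.Context, doRequest func() (*http.Response, error)) (*http.Response, error) {
	stop := make(chan struct{})
	done := make(chan struct{})

	go func() {
		defer close(done)
		ticker := time.NewTicker(10 * time.Second) // interval is illustrative
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				// SSE comment lines (": ...") are ignored by clients but keep
				// proxies and load balancers from idling out the connection.
				_, _ = c.Writer.Write([]byte(": PING\n\n"))
				c.Writer.Flush()
			}
		}
	}()

	resp, err := doRequest() // pings continue while this call blocks
	close(stop)
	<-done // ensure no ping interleaves with the real response body
	return resp, err
}

The essential point is that the ticker goroutine is running while doRequest blocks on the upstream, which is exactly what the change in TextHelper below arranges.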
445 lines · 16 KiB · Go
package relay

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/model"
	relaycommon "one-api/relay/common"
	relayconstant "one-api/relay/constant"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/setting"
	"one-api/setting/model_setting"
	"strings"
	"time"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/shopspring/decimal"

	"github.com/gin-gonic/gin"
)

func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
	textRequest := &dto.GeneralOpenAIRequest{}
	err := common.UnmarshalBodyReusable(c, textRequest)
	if err != nil {
		return nil, err
	}
	if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}

	if textRequest.MaxTokens > math.MaxInt32/2 {
		return nil, errors.New("max_tokens is invalid")
	}
	if textRequest.Model == "" {
		return nil, errors.New("model is required")
	}
	switch relayInfo.RelayMode {
	case relayconstant.RelayModeCompletions:
		if textRequest.Prompt == "" {
			return nil, errors.New("field prompt is required")
		}
	case relayconstant.RelayModeChatCompletions:
		if len(textRequest.Messages) == 0 {
			return nil, errors.New("field messages is required")
		}
	case relayconstant.RelayModeEmbeddings:
	case relayconstant.RelayModeModerations:
		if textRequest.Input == nil || textRequest.Input == "" {
			return nil, errors.New("field input is required")
		}
	case relayconstant.RelayModeEdits:
		if textRequest.Instruction == "" {
			return nil, errors.New("field instruction is required")
		}
	}
	relayInfo.IsStream = textRequest.Stream
	return textRequest, nil
}

func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {

	relayInfo := relaycommon.GenRelayInfo(c)

	// get & validate textRequest
	textRequest, err := getAndValidateTextRequest(c, relayInfo)
	if err != nil {
		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
	}

	if setting.ShouldCheckPromptSensitive() {
		words, err := checkRequestSensitive(textRequest, relayInfo)
		if err != nil {
			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
			return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
		}
	}

	err = helper.ModelMappedHelper(c, relayInfo)
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
	}

	textRequest.Model = relayInfo.UpstreamModelName

	// Get promptTokens; if they already exist in the context, use them directly
	var promptTokens int
	if value, exists := c.Get("prompt_tokens"); exists {
		promptTokens = value.(int)
		relayInfo.PromptTokens = promptTokens
	} else {
		promptTokens, err = getPromptTokens(textRequest, relayInfo)
		// counting prompt tokens failed
		if err != nil {
			return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
		}
		c.Set("prompt_tokens", promptTokens)
	}

	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
	if err != nil {
		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
	}

	// pre-consume quota
	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
	if openaiErr != nil {
		return openaiErr
	}
	defer func() {
		if openaiErr != nil {
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
		}
	}()
	includeUsage := false
	// Determine whether the user asked for usage to be returned
	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
		includeUsage = true
	}

	// If StreamOptions is not supported (or the request is not streaming), set StreamOptions to nil
	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
		textRequest.StreamOptions = nil
	} else {
		// If StreamOptions is supported and not set in the request, apply it from the configuration
		if constant.ForceStreamOption {
			textRequest.StreamOptions = &dto.StreamOptions{
				IncludeUsage: true,
			}
		}
	}

	if includeUsage {
		relayInfo.ShouldIncludeUsage = true
	}

	adaptor := GetAdaptor(relayInfo.ApiType)
	if adaptor == nil {
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
	}
	adaptor.Init(relayInfo)
	var requestBody io.Reader

	if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
		body, err := common.GetRequestBody(c)
		if err != nil {
			return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(body)
	} else {
		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
		if err != nil {
			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
		}
		jsonData, err := json.Marshal(convertedRequest)
		if err != nil {
			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
		}

		// apply param override
		if len(relayInfo.ParamOverride) > 0 {
			reqMap := make(map[string]interface{})
			err = json.Unmarshal(jsonData, &reqMap)
			if err != nil {
				return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
			}
			for key, value := range relayInfo.ParamOverride {
				reqMap[key] = value
			}
			jsonData, err = json.Marshal(reqMap)
			if err != nil {
				return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
			}
		}

		if common.DebugEnabled {
			println("requestBody: ", string(jsonData))
		}
		requestBody = bytes.NewBuffer(jsonData)
	}

	var httpResp *http.Response
	var resp any

	if relayInfo.IsStream {
		// Streaming requests can use SSE pings to keep the connection alive and avoid timeouts.
		// Whether pinging is enabled is decided inside the function.
		resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody)
	} else {
		resp, err = adaptor.DoRequest(c, relayInfo, requestBody)
	}

	if err != nil {
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	if resp != nil {
		httpResp = resp.(*http.Response)
		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			openaiErr = service.RelayErrorHandler(httpResp, false)
			// reset status code
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
			return openaiErr
		}
	}

	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
	if openaiErr != nil {
		// reset status code
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}

	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
	} else {
		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
	}
	return nil
}

func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
	var promptTokens int
	var err error
	switch info.RelayMode {
	case relayconstant.RelayModeChatCompletions:
		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
	case relayconstant.RelayModeCompletions:
		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case relayconstant.RelayModeModerations:
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
	case relayconstant.RelayModeEmbeddings:
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
	default:
		err = errors.New("unknown relay mode")
		promptTokens = 0
	}
	info.PromptTokens = promptTokens
	return promptTokens, err
}

func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
	var err error
	var words []string
	switch info.RelayMode {
	case relayconstant.RelayModeChatCompletions:
		words, err = service.CheckSensitiveMessages(textRequest.Messages)
	case relayconstant.RelayModeCompletions:
		words, err = service.CheckSensitiveInput(textRequest.Prompt)
	case relayconstant.RelayModeModerations:
		words, err = service.CheckSensitiveInput(textRequest.Input)
	case relayconstant.RelayModeEmbeddings:
		words, err = service.CheckSensitiveInput(textRequest.Input)
	}
	return words, err
}

// Pre-consume quota and return the user's remaining quota
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
	if err != nil {
		return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota <= 0 {
		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	if userQuota-preConsumedQuota < 0 {
		return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
	}
	relayInfo.UserQuota = userQuota
	// Trust heuristic: skip pre-consumption when the balance exceeds 100x the estimated cost
	if userQuota > 100*preConsumedQuota {
		// The user's quota is ample; check whether the token quota is ample too
		if !relayInfo.TokenUnlimited {
			// Not an unlimited token; check whether the token quota is sufficient
			tokenQuota := c.GetInt("token_quota")
			if tokenQuota > 100*preConsumedQuota {
				// The token quota is sufficient; trust the token
				preConsumedQuota = 0
				common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
			}
		} else {
			// in this case, we do not pre-consume quota
			// because the user has enough quota
			preConsumedQuota = 0
			common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
		}
	}

	if preConsumedQuota > 0 {
		err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
		if err != nil {
			return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
		if err != nil {
			return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
		}
	}
	return preConsumedQuota, userQuota, nil
}

func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
	if preConsumedQuota != 0 {
		gopool.Go(func() {
			relayInfoCopy := *relayInfo

			err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
			if err != nil {
				common.SysError("error return pre-consumed quota: " + err.Error())
			}
		})
	}
}

func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
	if usage == nil {
		usage = &dto.Usage{
			PromptTokens:     relayInfo.PromptTokens,
			CompletionTokens: 0,
			TotalTokens:      relayInfo.PromptTokens,
		}
		extraContent += "(可能是请求出错)"
	}
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	promptTokens := usage.PromptTokens
	cacheTokens := usage.PromptTokensDetails.CachedTokens
	imageTokens := usage.PromptTokensDetails.ImageTokens
	completionTokens := usage.CompletionTokens
	modelName := relayInfo.OriginModelName

	tokenName := ctx.GetString("token_name")
	completionRatio := priceData.CompletionRatio
	cacheRatio := priceData.CacheRatio
	imageRatio := priceData.ImageRatio
	modelRatio := priceData.ModelRatio
	groupRatio := priceData.GroupRatio
	modelPrice := priceData.ModelPrice

	// Convert values to decimal for precise calculation
	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
	dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
	dImageTokens := decimal.NewFromInt(int64(imageTokens))
	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
	dCompletionRatio := decimal.NewFromFloat(completionRatio)
	dCacheRatio := decimal.NewFromFloat(cacheRatio)
	dImageRatio := decimal.NewFromFloat(imageRatio)
	dModelRatio := decimal.NewFromFloat(modelRatio)
	dGroupRatio := decimal.NewFromFloat(groupRatio)
	dModelPrice := decimal.NewFromFloat(modelPrice)
	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)

	ratio := dModelRatio.Mul(dGroupRatio)

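	// Illustrative arithmetic for the token-billed path below (hypothetical numbers):
	// prompt=1000 with 200 cached at cacheRatio 0.5, completion=500 at
	// completionRatio 2.0, modelRatio 1.0, groupRatio 1.0 gives
	// quota = ((1000-200) + 200*0.5 + 500*2.0) * 1.0 * 1.0 = 1900.
	// Note that when imageTokens > 0, the image adjustment replaces the
	// cache-adjusted prompt quota rather than combining with it.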
	var quotaCalculateDecimal decimal.Decimal
	if !priceData.UsePrice {
		nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
		cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)

		promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
		if imageTokens > 0 {
			nonImageTokens := dPromptTokens.Sub(dImageTokens)
			imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
			promptQuota = nonImageTokens.Add(imageTokensWithRatio)
		}

		completionQuota := dCompletionTokens.Mul(dCompletionRatio)

		quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)

		if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
			quotaCalculateDecimal = decimal.NewFromInt(1)
		}
	} else {
		quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
	}

	quota := int(quotaCalculateDecimal.Round(0).IntPart())
	totalTokens := promptTokens + completionTokens

	var logContent string
	if !priceData.UsePrice {
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
	} else {
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}

	// record the consume log even if quota is 0
	if totalTokens == 0 {
		// in this case, some error must have happened
		// we cannot just return, because we may still have to refund the pre-consumed quota
		quota = 0
		logContent += "(可能是上游超时)"
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}

	quotaDelta := quota - preConsumedQuota
	if quotaDelta != 0 {
		err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
		if err != nil {
			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
		}
	}

	logModel := modelName
	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
		logModel = "gpt-4-gizmo-*"
		logContent += fmt.Sprintf(",模型 %s", modelName)
	}
	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
		logModel = "gpt-4o-gizmo-*"
		logContent += fmt.Sprintf(",模型 %s", modelName)
	}
	if extraContent != "" {
		logContent += ", " + extraContent
	}
	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
	if imageTokens != 0 {
		other["image"] = true
		other["image_ratio"] = imageRatio
		other["image_output"] = imageTokens
	}
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}