diff --git a/relay/relay-text.go b/relay/relay-text.go index 66b60db5..1e2dfbaa 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -219,7 +219,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } if strings.HasPrefix(relayInfo.UpstreamModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") } else { postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") } diff --git a/relay/websocket.go b/relay/websocket.go index c05e70a9..75a7d1f0 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -13,24 +13,6 @@ import ( "one-api/setting" ) -//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { -// _, p, err := ws.ReadMessage() -// if err != nil { -// return nil, err -// } -// realtimeEvent := &dto.RealtimeEvent{} -// err = json.Unmarshal(p, realtimeEvent) -// if err != nil { -// return nil, err -// } -// // save the original request -// if realtimeEvent.Session == nil { -// return nil, errors.New("session object is nil") -// } -// c.Set("first_wss_request", p) -// return realtimeEvent, nil -//} - func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) @@ -129,32 +111,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, 
getModelPriceSuccess, "") + service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, + userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } - -//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { -// var promptTokens int -// var err error -// switch info.RelayMode { -// default: -// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) -// } -// info.PromptTokens = promptTokens -// return promptTokens, err -//} - -//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { -// var err error -// switch info.RelayMode { -// case relayconstant.RelayModeChatCompletions: -// err = service.CheckSensitiveMessages(textRequest.Messages) -// case relayconstant.RelayModeCompletions: -// err = service.CheckSensitiveInput(textRequest.Prompt) -// case relayconstant.RelayModeModerations: -// err = service.CheckSensitiveInput(textRequest.Input) -// case relayconstant.RelayModeEmbeddings: -// err = service.CheckSensitiveInput(textRequest.Input) -// } -// return err -//} diff --git a/service/quota.go b/service/quota.go index 19c7c057..234ddc5b 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,7 +3,6 @@ package service import ( "errors" "fmt" - "github.com/gin-gonic/gin" "math" "one-api/common" "one-api/dto" @@ -12,8 +11,47 @@ import ( "one-api/setting" "strings" "time" + + "github.com/gin-gonic/gin" ) +type TokenDetails struct { + TextTokens int + AudioTokens int +} + +type QuotaInfo struct { + InputDetails TokenDetails + OutputDetails TokenDetails + ModelName string + UsePrice bool + ModelPrice float64 + ModelRatio float64 + GroupRatio float64 +} + +func calculateAudioQuota(info QuotaInfo) int { + if info.UsePrice { + return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio) + } + + completionRatio := common.GetCompletionRatio(info.ModelName) + 
audioRatio := common.GetAudioRatio(info.ModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName) + ratio := info.GroupRatio * info.ModelRatio + + quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio)) + quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) + + int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + + return quota +} + func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { if relayInfo.UsePrice { return nil @@ -33,23 +71,26 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag textOutTokens := usage.OutputTokenDetails.TextTokens audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens - - completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(modelName) groupRatio := setting.GetGroupRatio(relayInfo.Group) modelRatio := common.GetModelRatio(modelName) - ratio := groupRatio * modelRatio - - quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 + quotaInfo := QuotaInfo{ + InputDetails: TokenDetails{ + TextTokens: textInputTokens, + AudioTokens: audioInputTokens, + }, + OutputDetails: TokenDetails{ + TextTokens: textOutTokens, + AudioTokens: audioOutTokens, + }, + ModelName: modelName, + UsePrice: relayInfo.UsePrice, + ModelRatio: modelRatio, + GroupRatio: groupRatio, } + quota := 
calculateAudioQuota(quotaInfo) + if userQuota < quota { return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) } @@ -67,8 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, - groupRatio float64, + usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64, usePrice bool, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() @@ -83,17 +123,24 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName) - quota := 0 - if !usePrice { - quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + quotaInfo := QuotaInfo{ + InputDetails: TokenDetails{ + TextTokens: textInputTokens, + AudioTokens: audioInputTokens, + }, + OutputDetails: TokenDetails{ + TextTokens: textOutTokens, + AudioTokens: audioOutTokens, + }, + ModelName: modelName, + UsePrice: usePrice, + ModelPrice: modelPrice, + ModelRatio: modelRatio, + GroupRatio: groupRatio, } + + quota := calculateAudioQuota(quotaInfo) + totalTokens := usage.TotalTokens var logContent string if !usePrice { @@ -111,21 +158,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota 
%d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { - //if sensitiveResp != nil { - // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - //} - //quotaDelta := quota - preConsumedQuota - //if quotaDelta != 0 { - // err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) - // if err != nil { - // common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - // } - //} - - //err := model.CacheUpdateUserQuota(relayInfo.UserId) - //if err != nil { - // common.LogError(ctx, "error update user quota cache: "+err.Error()) - //} model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } @@ -140,8 +172,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, - groupRatio float64, + usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64, usePrice bool, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() @@ -156,17 +187,24 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.UpstreamModelName) - quota := 0 - if !usePrice { - quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + quotaInfo := QuotaInfo{ + 
InputDetails: TokenDetails{ + TextTokens: textInputTokens, + AudioTokens: audioInputTokens, + }, + OutputDetails: TokenDetails{ + TextTokens: textOutTokens, + AudioTokens: audioOutTokens, + }, + ModelName: relayInfo.UpstreamModelName, + UsePrice: usePrice, + ModelPrice: modelPrice, + ModelRatio: modelRatio, + GroupRatio: groupRatio, } + + quota := calculateAudioQuota(quotaInfo) + totalTokens := usage.TotalTokens var logContent string if !usePrice {