refactor: realtime quota

This commit is contained in:
1808837298@qq.com
2025-01-04 15:46:35 +08:00
parent b5de003ec2
commit 99245e4c1f
3 changed files with 91 additions and 98 deletions

View File

@@ -219,7 +219,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
if strings.HasPrefix(relayInfo.UpstreamModelName, "gpt-4o-audio") { if strings.HasPrefix(relayInfo.UpstreamModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
} else { } else {
postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
} }

View File

@@ -13,24 +13,6 @@ import (
"one-api/setting" "one-api/setting"
) )
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
// _, p, err := ws.ReadMessage()
// if err != nil {
// return nil, err
// }
// realtimeEvent := &dto.RealtimeEvent{}
// err = json.Unmarshal(p, realtimeEvent)
// if err != nil {
// return nil, err
// }
// // save the original request
// if realtimeEvent.Session == nil {
// return nil, errors.New("session object is nil")
// }
// c.Set("first_wss_request", p)
// return realtimeEvent, nil
//}
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws) relayInfo := relaycommon.GenRelayInfoWs(c, ws)
@@ -129,32 +111,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
return nil return nil
} }
//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
// var promptTokens int
// var err error
// switch info.RelayMode {
// default:
// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
// }
// info.PromptTokens = promptTokens
// return promptTokens, err
//}
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
// var err error
// switch info.RelayMode {
// case relayconstant.RelayModeChatCompletions:
// err = service.CheckSensitiveMessages(textRequest.Messages)
// case relayconstant.RelayModeCompletions:
// err = service.CheckSensitiveInput(textRequest.Prompt)
// case relayconstant.RelayModeModerations:
// err = service.CheckSensitiveInput(textRequest.Input)
// case relayconstant.RelayModeEmbeddings:
// err = service.CheckSensitiveInput(textRequest.Input)
// }
// return err
//}

View File

@@ -3,7 +3,6 @@ package service
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"math" "math"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
@@ -12,8 +11,47 @@ import (
"one-api/setting" "one-api/setting"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
) )
type TokenDetails struct {
TextTokens int
AudioTokens int
}
type QuotaInfo struct {
InputDetails TokenDetails
OutputDetails TokenDetails
ModelName string
UsePrice bool
ModelPrice float64
ModelRatio float64
GroupRatio float64
}
func calculateAudioQuota(info QuotaInfo) int {
if info.UsePrice {
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
}
completionRatio := common.GetCompletionRatio(info.ModelName)
audioRatio := common.GetAudioRatio(info.ModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
ratio := info.GroupRatio * info.ModelRatio
quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
return quota
}
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
if relayInfo.UsePrice { if relayInfo.UsePrice {
return nil return nil
@@ -33,23 +71,26 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := setting.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName) modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) TextTokens: textInputTokens,
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) AudioTokens: audioInputTokens,
},
quota = int(math.Round(float64(quota) * ratio)) OutputDetails: TokenDetails{
if ratio != 0 && quota <= 0 { TextTokens: textOutTokens,
quota = 1 AudioTokens: audioOutTokens,
},
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
} }
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota { if userQuota < quota {
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
@@ -67,8 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
} }
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) { modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
@@ -83,17 +123,23 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quota := 0 quotaInfo := QuotaInfo{
if !usePrice { InputDetails: TokenDetails{
quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio)) TextTokens: textInputTokens,
quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio)) AudioTokens: audioInputTokens,
quota = int(math.Round(float64(quota) * ratio)) },
if ratio != 0 && quota <= 0 { OutputDetails: TokenDetails{
quota = 1 TextTokens: textOutTokens,
} AudioTokens: audioOutTokens,
} else { },
quota = int(modelPrice * common.QuotaPerUnit * groupRatio) ModelName: modelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
} }
quota := calculateAudioQuota(quotaInfo)
totalTokens := usage.TotalTokens totalTokens := usage.TotalTokens
var logContent string var logContent string
if !usePrice { if !usePrice {
@@ -111,21 +157,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) "tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else { } else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
//quotaDelta := quota - preConsumedQuota
//if quotaDelta != 0 {
// err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// if err != nil {
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
// }
//}
//err := model.CacheUpdateUserQuota(relayInfo.UserId)
//if err != nil {
// common.LogError(ctx, "error update user quota cache: "+err.Error())
//}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }
@@ -140,8 +171,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
} }
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) { modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
@@ -156,17 +186,23 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.UpstreamModelName) audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.UpstreamModelName)
quota := 0 quotaInfo := QuotaInfo{
if !usePrice { InputDetails: TokenDetails{
quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio)) TextTokens: textInputTokens,
quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio)) AudioTokens: audioInputTokens,
quota = int(math.Round(float64(quota) * ratio)) },
if ratio != 0 && quota <= 0 { OutputDetails: TokenDetails{
quota = 1 TextTokens: textOutTokens,
} AudioTokens: audioOutTokens,
} else { },
quota = int(modelPrice * common.QuotaPerUnit * groupRatio) ModelName: relayInfo.UpstreamModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
} }
quota := calculateAudioQuota(quotaInfo)
totalTokens := usage.TotalTokens totalTokens := usage.TotalTokens
var logContent string var logContent string
if !usePrice { if !usePrice {