refactor: Improve quota calculation precision using floating-point arithmetic

This commit is contained in:
1808837298@qq.com
2025-03-08 16:44:08 +08:00
parent 1f4ebddcfa
commit bb848b2fe0
2 changed files with 16 additions and 14 deletions

View File

@@ -320,17 +320,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
groupRatio := priceData.GroupRatio groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice modelPrice := priceData.ModelPrice
quota := 0 quotaCalculate := 0.0
if !priceData.UsePrice { if !priceData.UsePrice {
quota = (promptTokens - cacheTokens) + int(math.Round(float64(cacheTokens)*cacheRatio)) quotaCalculate = float64(promptTokens-cacheTokens) + float64(cacheTokens)*cacheRatio
quota += int(math.Round(float64(completionTokens) * completionRatio)) quotaCalculate += float64(completionTokens) * completionRatio
quota = int(math.Round(float64(quota) * ratio)) quotaCalculate = quotaCalculate * ratio
if ratio != 0 && quota <= 0 { if ratio != 0 && quotaCalculate <= 0 {
quota = 1 quotaCalculate = 1
} }
} else { } else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio) quotaCalculate = modelPrice * common.QuotaPerUnit * groupRatio
} }
quota := int(quotaCalculate)
totalTokens := promptTokens + completionTokens totalTokens := promptTokens + completionTokens
var logContent string var logContent string

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common" "one-api/common"
constant2 "one-api/constant" constant2 "one-api/constant"
"one-api/dto" "one-api/dto"
@@ -44,16 +43,18 @@ func calculateAudioQuota(info QuotaInfo) int {
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName) audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
ratio := info.GroupRatio * info.ModelRatio ratio := info.GroupRatio * info.ModelRatio
quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio)) quota := 0.0
quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) + quota += float64(info.InputDetails.TextTokens)
int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio)) quota += float64(info.OutputDetails.TextTokens) * completionRatio
quota += float64(info.InputDetails.AudioTokens) * audioRatio
quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
quota = int(math.Round(float64(quota) * ratio)) quota = quota * ratio
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {
quota = 1 quota = 1
} }
return quota return int(quota)
} }
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {