feat: Improve decimal precision for quota and payment calculations

- Added github.com/shopspring/decimal for precise floating-point calculations
- Refactored quota and payment calculations in multiple files to use decimal arithmetic
- Updated go.mod and go.sum to include decimal library
- Improved precision in topup, relay, and quota service calculations
- Added support for more OpenAI model variants in cache ratio settings
This commit is contained in:
1808837298@qq.com
2025-03-08 21:55:50 +08:00
parent 3352bacd35
commit 68097c132d
6 changed files with 111 additions and 56 deletions

View File

@@ -3,7 +3,6 @@ package service
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -15,7 +14,10 @@ import (
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
)
type TokenDetails struct {
@@ -35,26 +37,41 @@ type QuotaInfo struct {
func calculateAudioQuota(info QuotaInfo) int {
if info.UsePrice {
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
modelPrice := decimal.NewFromFloat(info.ModelPrice)
quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
groupRatio := decimal.NewFromFloat(info.GroupRatio)
quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
return int(quota.IntPart())
}
completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
audioRatio := operation_setting.GetAudioRatio(info.ModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
ratio := info.GroupRatio * info.ModelRatio
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
quota := 0.0
quota += float64(info.InputDetails.TextTokens)
quota += float64(info.OutputDetails.TextTokens) * completionRatio
quota += float64(info.InputDetails.AudioTokens) * audioRatio
quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
ratio := groupRatio.Mul(modelRatio)
quota = quota * ratio
if ratio != 0 && quota <= 0 {
quota = 1
inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
quota := decimal.Zero
quota = quota.Add(inputTextTokens)
quota = quota.Add(outputTextTokens.Mul(completionRatio))
quota = quota.Add(inputAudioTokens.Mul(audioRatio))
quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
quota = quota.Mul(ratio)
// If ratio is not zero and quota is less than or equal to zero, set quota to 1
if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
quota = decimal.NewFromInt(1)
}
return int(quota)
return int(quota.Round(0).IntPart())
}
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
@@ -124,9 +141,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := operation_setting.GetCompletionRatio(modelName)
audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -148,7 +165,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -170,7 +188,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -186,9 +205,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
@@ -215,7 +234,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -244,7 +264,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}