From 68097c132d2b465c7058c14bdce4ba0f52eef52f Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sat, 8 Mar 2025 21:55:50 +0800 Subject: [PATCH] feat: Improve decimal precision for quota and payment calculations - Added github.com/shopspring/decimal for precise floating-point calculations - Refactored quota and payment calculations in multiple files to use decimal arithmetic - Updated go.mod and go.sum to include decimal library - Improved precision in topup, relay, and quota service calculations - Added support for more OpenAI model variants in cache ratio settings --- controller/topup.go | 39 +++++++++---- go.mod | 2 +- go.sum | 4 +- relay/relay-text.go | 46 +++++++++------ service/quota.go | 71 +++++++++++++++--------- setting/operation_setting/cache_ratio.go | 5 ++ 6 files changed, 111 insertions(+), 56 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index a342ec3a..ecb48298 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -2,9 +2,6 @@ package controller import ( "fmt" - "github.com/Calcium-Ion/go-epay/epay" - "github.com/gin-gonic/gin" - "github.com/samber/lo" "log" "net/url" "one-api/common" @@ -14,6 +11,11 @@ import ( "strconv" "sync" "time" + + "github.com/Calcium-Ion/go-epay/epay" + "github.com/gin-gonic/gin" + "github.com/samber/lo" + "github.com/shopspring/decimal" ) type EpayRequest struct { @@ -42,22 +44,32 @@ func GetEpayClient() *epay.Client { } func getPayMoney(amount float64, group string) float64 { + dAmount := decimal.NewFromFloat(amount) + if !common.DisplayInCurrencyEnabled { - amount = amount / common.QuotaPerUnit + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + dAmount = dAmount.Div(dQuotaPerUnit) } - // 别问为什么用float64,问就是这么点钱没必要 + topupGroupRatio := common.GetTopupGroupRatio(group) if topupGroupRatio == 0 { topupGroupRatio = 1 } - payMoney := amount * setting.Price * topupGroupRatio - return payMoney + + dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio) + dPrice := decimal.NewFromFloat(setting.Price) + + payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio) + + return payMoney.InexactFloat64() } func getMinTopup() int { minTopup := setting.MinTopUp if !common.DisplayInCurrencyEnabled { - minTopup = minTopup * int(common.QuotaPerUnit) + dMinTopup := decimal.NewFromInt(int64(minTopup)) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart()) } return minTopup } @@ -118,7 +130,9 @@ func RequestEpay(c *gin.Context) { } amount := req.Amount if !common.DisplayInCurrencyEnabled { - amount = amount / int(common.QuotaPerUnit) + dAmount := decimal.NewFromInt(int64(amount)) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + amount = int(dAmount.Div(dQuotaPerUnit).IntPart()) } topUp := &model.TopUp{ UserId: id, @@ -210,13 +224,16 @@ func EpayNotify(c *gin.Context) { } //user, _ := model.GetUserById(topUp.UserId, false) //user.Quota += topUp.Amount * 500000 - err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true) + dAmount := decimal.NewFromInt(int64(topUp.Amount)) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart()) + err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true) if err != nil { log.Printf("易支付回调更新用户失败: %v", topUp) return } log.Printf("易支付回调更新用户成功 %v", topUp) - model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money)) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money)) } } else { log.Printf("易支付异常回调: %v", verifyInfo) diff --git a/go.mod b/go.mod index c9da57c6..d5686d03 100644 --- a/go.mod +++ b/go.mod @@ -22,12 +22,12 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 - github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.7 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible + github.com/shopspring/decimal v1.4.0 golang.org/x/crypto v0.27.0 golang.org/x/image v0.23.0 golang.org/x/net v0.28.0 diff --git a/go.sum b/go.sum index 0194ca30..44de7e52 100644 --- a/go.sum +++ b/go.sum @@ -117,8 +117,6 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= -github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -183,6 +181,8 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/relay/relay-text.go b/relay/relay-text.go index af1eeca5..b1c9d515 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -21,6 +20,9 @@ import ( "strings" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/shopspring/decimal" + "github.com/gin-gonic/gin" ) @@ -315,23 +317,40 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, tokenName := ctx.GetString("token_name") completionRatio := priceData.CompletionRatio cacheRatio := priceData.CacheRatio - ratio := priceData.ModelRatio * priceData.GroupRatio modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio modelPrice := priceData.ModelPrice - quotaCalculate := 0.0 + // Convert values to decimal for precise calculation + dPromptTokens := decimal.NewFromInt(int64(promptTokens)) + dCacheTokens := decimal.NewFromInt(int64(cacheTokens)) + dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) + dCompletionRatio := decimal.NewFromFloat(completionRatio) + dCacheRatio := decimal.NewFromFloat(cacheRatio) + dModelRatio := decimal.NewFromFloat(modelRatio) + dGroupRatio := decimal.NewFromFloat(groupRatio) + dModelPrice := decimal.NewFromFloat(modelPrice) + dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + + ratio := dModelRatio.Mul(dGroupRatio) + + var quotaCalculateDecimal decimal.Decimal if !priceData.UsePrice { - quotaCalculate = float64(promptTokens-cacheTokens) + float64(cacheTokens)*cacheRatio - quotaCalculate += float64(completionTokens) * completionRatio - quotaCalculate = quotaCalculate * ratio - if ratio != 0 && quotaCalculate <= 0 { - quotaCalculate = 1 + nonCachedTokens := dPromptTokens.Sub(dCacheTokens) + cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio) + promptQuota := nonCachedTokens.Add(cachedTokensWithRatio) + completionQuota := dCompletionTokens.Mul(dCompletionRatio) + + quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio) + + if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) { + quotaCalculateDecimal = decimal.NewFromInt(1) } } else { - quotaCalculate = modelPrice * common.QuotaPerUnit * groupRatio + quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } - quota := int(quotaCalculate) + + quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens var logContent string @@ -350,9 +369,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { - //if sensitiveResp != nil { - // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - //} quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) @@ -379,8 +395,4 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) - - //if quota != 0 { - // - //} } diff --git a/service/quota.go b/service/quota.go index 6fec7252..e19f1b82 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,7 +3,6 @@ package service import ( "errors" "fmt" - "github.com/bytedance/gopkg/util/gopool" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -15,7 +14,10 @@ import ( "strings" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" ) type TokenDetails struct { @@ -35,26 +37,41 @@ type QuotaInfo struct { func calculateAudioQuota(info QuotaInfo) int { if info.UsePrice { - return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio) + modelPrice := decimal.NewFromFloat(info.ModelPrice) + quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) + groupRatio := decimal.NewFromFloat(info.GroupRatio) + + quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio) + return int(quota.IntPart()) } - completionRatio := operation_setting.GetCompletionRatio(info.ModelName) - audioRatio := operation_setting.GetAudioRatio(info.ModelName) - audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName) - ratio := info.GroupRatio * info.ModelRatio + completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName)) + audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName)) + audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName)) - quota := 0.0 - quota += float64(info.InputDetails.TextTokens) - quota += float64(info.OutputDetails.TextTokens) * completionRatio - quota += float64(info.InputDetails.AudioTokens) * audioRatio - quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio + groupRatio := decimal.NewFromFloat(info.GroupRatio) + modelRatio := decimal.NewFromFloat(info.ModelRatio) + ratio := groupRatio.Mul(modelRatio) - quota = quota * ratio - if ratio != 0 && quota <= 0 { - quota = 1 + inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens)) + outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens)) + inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens)) + outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens)) + + quota := decimal.Zero + quota = quota.Add(inputTextTokens) + quota = quota.Add(outputTextTokens.Mul(completionRatio)) + quota = quota.Add(inputAudioTokens.Mul(audioRatio)) + quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio)) + + quota = quota.Mul(ratio) + + // If ratio is not zero and quota is less than or equal to zero, set quota to 1 + if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) { + quota = decimal.NewFromInt(1) } - return int(quota) + return int(quota.Round(0).IntPart()) } func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { @@ -124,9 +141,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioOutTokens := usage.OutputTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := operation_setting.GetCompletionRatio(modelName) - audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName) + completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName)) + audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) + audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName)) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -148,7 +165,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod totalTokens := usage.TotalTokens var logContent string if !usePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) + logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", + modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) } @@ -170,7 +188,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod if extraContent != "" { logContent += ", " + extraContent } - other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice) + other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } @@ -186,9 +205,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName) - audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName) + completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName)) + audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) + audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio @@ -215,7 +234,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, totalTokens := usage.TotalTokens var logContent string if !usePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) + logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", + modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) } @@ -244,7 +264,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice) + other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } diff --git a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index 545a5892..98f022ed 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -8,12 +8,17 @@ import ( var defaultCacheRatio = map[string]float64{ "gpt-4": 0.5, + "o1": 0.5, "o1-2024-12-17": 0.5, "o1-preview-2024-09-12": 0.5, + "o1-preview": 0.5, "o1-mini-2024-09-12": 0.5, + "o1-mini": 0.5, "gpt-4o-2024-11-20": 0.5, "gpt-4o-2024-08-06": 0.5, + "gpt-4o": 0.5, "gpt-4o-mini-2024-07-18": 0.5, + "gpt-4o-mini": 0.5, "gpt-4o-realtime-preview": 0.5, "gpt-4o-mini-realtime-preview": 0.5, "deepseek-chat": 0.1,