diff --git a/relay/relay-text.go b/relay/relay-text.go index ddf6767d..af1eeca5 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -320,19 +320,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, groupRatio := priceData.GroupRatio modelPrice := priceData.ModelPrice - quota := 0 + quotaCalculate := 0.0 if !priceData.UsePrice { - quota = (promptTokens - cacheTokens) + int(math.Round(float64(cacheTokens)*cacheRatio)) - quota += int(math.Round(float64(completionTokens) * completionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 + quotaCalculate = float64(promptTokens-cacheTokens) + float64(cacheTokens)*cacheRatio + quotaCalculate += float64(completionTokens) * completionRatio + quotaCalculate = quotaCalculate * ratio + if ratio != 0 && quotaCalculate <= 0 { + quotaCalculate = 1 } } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + quotaCalculate = modelPrice * common.QuotaPerUnit * groupRatio } + quota := int(quotaCalculate) totalTokens := promptTokens + completionTokens - + var logContent string if !priceData.UsePrice { logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) diff --git a/service/quota.go b/service/quota.go index e4499ff9..6fec7252 100644 --- a/service/quota.go +++ b/service/quota.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "github.com/bytedance/gopkg/util/gopool" - "math" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -44,16 +43,18 @@ func calculateAudioQuota(info QuotaInfo) int { audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName) ratio := info.GroupRatio * info.ModelRatio - quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio)) - quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) + - int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio)) + quota := 0.0 + quota += float64(info.InputDetails.TextTokens) + quota += float64(info.OutputDetails.TextTokens) * completionRatio + quota += float64(info.InputDetails.AudioTokens) * audioRatio + quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio - quota = int(math.Round(float64(quota) * ratio)) + quota = quota * ratio if ratio != 0 && quota <= 0 { quota = 1 } - return quota + return int(quota) } func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {