From dd82618c0554e7fca3a658b5e9d1932dd84c7e96 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 17 Mar 2025 17:52:54 +0800 Subject: [PATCH] refactor: Improve token quota consumption logic --- relay/helper/price.go | 15 +++++++++++++-- relay/relay-text.go | 17 +++++++++-------- service/quota.go | 33 ++++++++++++++++----------------- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/relay/helper/price.go b/relay/helper/price.go index 1ae3d2fc..5bf8f9ba 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -20,6 +20,10 @@ type PriceData struct { ShouldPreConsumedQuota int } +func (p PriceData) ToSetting() string { + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota) +} + func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false) groupRatio := setting.GetGroupRatio(info.Group) @@ -50,7 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } else { preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) } - return PriceData{ + + priceData := PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, @@ -59,5 +64,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens CacheRatio: cacheRatio, CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, - }, nil + } + + if common.DebugEnabled { + println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting())) + } + + return priceData, nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index a61718fc..c871b80b 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -109,7 +109,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } @@ -372,17 +372,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { - quotaDelta := quota - preConsumedQuota - if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + logModel := modelName if strings.HasPrefix(logModel, "gpt-4-gizmo") { logModel = "gpt-4-gizmo-*" diff --git a/service/quota.go b/service/quota.go index ec5af57a..0d11b4a0 100644 --- a/service/quota.go +++ b/service/quota.go @@ -243,20 +243,18 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { - //if sensitiveResp != nil { - // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - //} - quotaDelta := quota - preConsumedQuota - if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, @@ -318,17 +316,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) } else { - quotaDelta := quota - preConsumedQuota - if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + logModel := relayInfo.OriginModelName if extraContent != "" { logContent += ", " + extraContent