From 0f35d2368f918bf51569fd11fbe1111171353a41 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Tue, 17 Jun 2025 21:05:35 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20enhance=20group=20ratio=20h?= =?UTF-8?q?andling=20in=20pricing=20calculations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 6 ++-- model/cache.go | 9 ++++-- model/option.go | 6 ++-- relay/helper/price.go | 65 +++++++++++++++++++++++++------------- relay/relay-image.go | 2 +- relay/relay-text.go | 5 ++- relay/websocket.go | 43 ++++--------------------- service/quota.go | 46 +++++++-------------------- 8 files changed, 76 insertions(+), 106 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 52f8a7ef..d162d8cf 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, - usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, + usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) @@ -312,7 +312,7 @@ func testAllChannels(notify bool) error { channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } - + if notify { service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } diff --git a/model/cache.go b/model/cache.go index 1d7d2f25..3e5eb4c4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -3,7 +3,6 @@ package model import ( "errors" "fmt" - "log" "math/rand" "one-api/common" "one-api/setting" @@ -88,14 +87,18 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, return nil, selectGroup, errors.New("auto groups is not enabled") } for _, autoGroup := range setting.AutoGroups { - log.Printf("autoGroup: %s", autoGroup) + if common.DebugEnabled { + println("autoGroup:", autoGroup) + } channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry) if channel == nil { continue } else { c.Set("auto_group", autoGroup) selectGroup = autoGroup - log.Printf("selectGroup: %s", selectGroup) + if common.DebugEnabled { + println("selectGroup:", selectGroup) + } break } } diff --git a/model/option.go b/model/option.go index 89ab8506..1391b203 100644 --- a/model/option.go +++ b/model/option.go @@ -194,7 +194,7 @@ func updateOptionMap(key string, value string) (err error) { common.ImageDownloadPermission = intValue } } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": @@ -263,6 +263,8 @@ func updateOptionMap(key string, value string) (err error) { common.SMTPSSLEnabled = boolValue case "WorkerAllowHttpImageRequestEnabled": setting.WorkerAllowHttpImageRequestEnabled = boolValue + case "DefaultUseAutoGroup": + setting.DefaultUseAutoGroup = boolValue } } switch key { @@ -291,8 +293,6 @@ func updateOptionMap(key string, value string) (err error) { err = setting.UpdateChatsByJsonString(value) case "AutoGroups": err = setting.UpdateAutoGroupsByJsonString(value) - case "DefaultUseAutoGroup": - setting.DefaultUseAutoGroup = value == "true" case "CustomCallbackAddress": setting.CustomCallbackAddress = value case "EpayId": diff --git a/relay/helper/price.go b/relay/helper/price.go index 6ecebac5..326790b4 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -2,7 +2,6 @@ package helper import ( "fmt" - "log" "one-api/common" constant2 "one-api/constant" relaycommon "one-api/relay/common" @@ -12,6 +11,11 @@ import ( "github.com/gin-gonic/gin" ) +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 +} + type PriceData struct { ModelPrice float64 ModelRatio float64 @@ -19,32 +23,50 @@ type PriceData struct { CacheRatio float64 CacheCreationRatio float64 ImageRatio float64 - GroupRatio float64 - UserGroupRatio float64 UsePrice bool ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} + +// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { + groupRatioInfo := GroupRatioInfo{ + GroupRatio: 1.0, // default ratio + GroupSpecialRatio: 1.0, // default user group ratio + } + + // check auto group + autoGroup, exists := ctx.Get("auto_group") + if exists { + if common.DebugEnabled { + println(fmt.Sprintf("final group: %s", autoGroup)) + } + relayInfo.Group = autoGroup.(string) + } + + // check user group special ratio + userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) + if ok { + // user group special ratio + groupRatioInfo.GroupSpecialRatio = userGroupRatio + groupRatioInfo.GroupRatio = userGroupRatio + } else { + // normal group ratio + groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group) + } + + return groupRatioInfo } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false) - groupRatio := setting.GetGroupRatio(info.Group) - var userGroupRatio float64 - autoGroup, exists := c.Get("auto_group") - if exists { - groupRatio = setting.GetGroupRatio(autoGroup.(string)) - log.Printf("final group ratio: %f", groupRatio) - info.Group = autoGroup.(string) - } - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group) - if ok { - actualGroupRatio = userGroupRatio - } - groupRatio = actualGroupRatio + + groupRatioInfo := HandleGroupRatio(c, info) + var preConsumedQuota int var modelRatio float64 var completionRatio float64 @@ -74,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName) - ratio := modelRatio * groupRatio + ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } priceData := PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, - GroupRatio: groupRatio, - UserGroupRatio: userGroupRatio, + GroupRatioInfo: groupRatioInfo, UsePrice: usePrice, CacheRatio: cacheRatio, ImageRatio: imageRatio, diff --git a/relay/relay-image.go b/relay/relay-image.go index dc63cce8..197a8af6 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // reset model price priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) diff --git a/relay/relay-text.go b/relay/relay-text.go index 3aa382e8..c94e0f50 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, cacheRatio := priceData.CacheRatio imageRatio := priceData.ImageRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio diff --git a/relay/websocket.go b/relay/websocket.go index c815eb71..571f3a82 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -6,12 +6,10 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" - "one-api/setting" - "one-api/setting/operation_setting" ) func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //isModelMapped = true } } - //relayInfo.UpstreamModelName = textRequest.Model - modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - //err := service.SensitiveWordsCheck(textRequest) - - //if constant.ShouldCheckPromptSensitive() { - // err = checkRequestSensitive(textRequest, relayInfo) - // if err != nil { - // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) - // } - //} - - //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo) - //// count messages token error 计算promptTokens错误 - //if err != nil { - // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) - //} - // - if !getModelPriceSuccess { - preConsumedTokens := common.PreConsumedQuota - //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { - // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) - //} - modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - relayInfo.UsePrice = true + priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi return openaiErr } service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + userQuota, priceData, "") return nil } diff --git a/service/quota.go b/service/quota.go index 75b186ae..0fb9e67c 100644 --- a/service/quota.go +++ b/service/quota.go @@ -144,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { + usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -159,18 +158,11 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName)) - autoGroup, exists := ctx.Get("auto_group") - if exists { - groupRatio = setting.GetGroupRatio(autoGroup.(string)) - log.Printf("final group ratio: %f", groupRatio) - relayInfo.Group = autoGroup.(string) - } + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -183,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod ModelName: modelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -215,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } @@ -231,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, tokenName := ctx.GetString("token_name") completionRatio := priceData.CompletionRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio cacheRatio := priceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens @@ -282,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } @@ -303,23 +294,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice usePrice := priceData.UsePrice - autoGroup, exists := ctx.Get("auto_group") - if exists { - groupRatio = setting.GetGroupRatio(autoGroup.(string)) - log.Printf("final group ratio: %f", groupRatio) - relayInfo.Group = autoGroup.(string) - } - - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } - quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -332,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ModelName: relayInfo.OriginModelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -372,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) }