diff --git a/constant/context_key.go b/constant/context_key.go index 4b4d5cae..895b0fcb 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -7,4 +7,5 @@ const ( ContextKeyUserStatus = "user_status" ContextKeyUserEmail = "user_email" ContextKeyUserGroup = "user_group" + ContextKeyUsingGroup = "group" ) diff --git a/controller/channel-test.go b/controller/channel-test.go index d54ccf0d..db8c9db0 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -171,7 +171,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) + quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } diff --git a/middleware/distributor.go b/middleware/distributor.go index 9d074ce8..e2f63602 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) { } userGroup = tokenGroup } - c.Set("group", userGroup) + c.Set(constant.ContextKeyUsingGroup, userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3759c363..f3fc9ce9 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -65,8 +65,8 @@ type RelayInfo struct { TokenId int TokenKey string UserId int - Group string - UserGroup string + UsingGroup string // 使用的分组 + UserGroup string // 用户所在分组 TokenUnlimited bool StartTime time.Time FirstResponseTime time.Time @@ -219,7 +219,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { tokenId := c.GetInt("token_id") tokenKey := c.GetString("token_key") userId := c.GetInt("id") - group := c.GetString("group") tokenUnlimited := c.GetBool("token_unlimited_quota") startTime := c.GetTime(constant.ContextKeyRequestStartTime) // firstResponseTime = time.Now() - 1 second @@ -239,7 +238,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { TokenId: tokenId, TokenKey: tokenKey, UserId: userId, - Group: group, + UsingGroup: c.GetString(constant.ContextKeyUsingGroup), UserGroup: c.GetString(constant.ContextKeyUserGroup), TokenUnlimited: tokenUnlimited, StartTime: startTime, diff --git a/relay/helper/price.go b/relay/helper/price.go index 1ee2767e..ab614cbd 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -13,6 +13,7 @@ import ( type GroupRatioInfo struct { GroupRatio float64 GroupSpecialRatio float64 + HasSpecialRatio bool } type PriceData struct { @@ -31,7 +32,7 @@ func (p PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } -// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present +// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { groupRatioInfo := GroupRatioInfo{ GroupRatio: 1.0, // default ratio @@ -44,18 +45,19 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR if common.DebugEnabled { println(fmt.Sprintf("final group: %s", autoGroup)) } - relayInfo.Group = autoGroup.(string) + relayInfo.UsingGroup = autoGroup.(string) } // check user group special ratio - userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) + userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) if ok { // user group special ratio groupRatioInfo.GroupSpecialRatio = userGroupRatio groupRatioInfo.GroupRatio = userGroupRatio + groupRatioInfo.HasSpecialRatio = true } else { // normal group ratio - groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group) + groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup) } return groupRatioInfo @@ -120,6 +122,35 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return priceData, nil } +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { + groupRatioInfo := HandleGroupRatio(c, info) + + modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) + // 如果没有配置价格,则使用默认价格 + if !success { + defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) + priceData := PerCallPriceData{ + ModelPrice: modelPrice, + Quota: quota, + GroupRatioInfo: groupRatioInfo, + } + return priceData +} + func ContainPriceOrRatio(modelName string) bool { _, ok := ratio_setting.GetModelPrice(modelName, false) if ok { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 6465dc88..b44890c1 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -13,9 +13,9 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" - "one-api/setting/ratio_setting" "strconv" "strings" "time" @@ -174,24 +174,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - groupRatio := ratio_setting.GetGroupRatio(group) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, group) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } + + priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ @@ -199,9 +184,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Description: err.Error(), } } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { + if userQuota-priceData.Quota < 0 { return &dto.MidjourneyResponse{ Code: 4, Description: "quota_not_enough", @@ -216,27 +200,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - //err = model.CacheUpdateUserQuota(userId) - // if err != nil { - // common.SysError("error update user quota cache: " + err.Error()) - // } - if quota != 0 { - tokenName := c.GetString("token_name") - gRatio := groupRatio - if hasUserGroupRatio { - gRatio = userGroupRatio - } - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, constant.MjActionSwapFace) - other := genMjOtherInfo(modelPrice, groupRatio, userGroupRatio, hasUserGroupRatio) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - quota, logContent, tokenId, userQuota, 0, false, group, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) + other := service.GenerateMjOtherInfo(priceData) + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, + priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) + model.UpdateChannelUsedQuota(channelId, priceData.Quota) } }() midjResponse := &mjResp.Response @@ -257,7 +232,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Progress: "0%", FailReason: "", ChannelId: c.GetInt("channel_id"), - Quota: quota, + Quota: priceData.Quota, } err = midjourneyTask.Insert() if err != nil { @@ -487,24 +462,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) modelName := service.CoverActionToModelName(midjRequest.Action) - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - groupRatio := ratio_setting.GetGroupRatio(group) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, group) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } + + priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ @@ -512,9 +472,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons Description: err.Error(), } } - quota := int(ratio * common.QuotaPerUnit) - if consumeQuota && userQuota-quota < 0 { + if consumeQuota && userQuota-priceData.Quota < 0 { return &dto.MidjourneyResponse{ Code: 4, Description: "quota_not_enough", @@ -529,23 +488,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func() { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := service.PostConsumeQuota(relayInfo, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - if quota != 0 { - tokenName := c.GetString("token_name") - gRatio := groupRatio - if hasUserGroupRatio { - gRatio = userGroupRatio - } - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, gRatio, midjRequest.Action, midjResponse.Result) - other := genMjOtherInfo(modelPrice, groupRatio, userGroupRatio, hasUserGroupRatio) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - quota, logContent, tokenId, userQuota, 0, false, group, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) + other := service.GenerateMjOtherInfo(priceData) + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, + priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) + model.UpdateChannelUsedQuota(channelId, priceData.Quota) } }() @@ -573,7 +526,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons Progress: "0%", FailReason: "", ChannelId: c.GetInt("channel_id"), - Quota: quota, + Quota: priceData.Quota, } if midjResponse.Code == 3 { //无实例账号自动禁用渠道(No available account instance) @@ -673,13 +626,3 @@ func getMjRequestPath(path string) string { } return requestURL } - -func genMjOtherInfo(modelPrice, groupRatio, userGroupRatio float64, hasUserGroupRatio bool) map[string]interface{} { - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - if hasUserGroupRatio && userGroupRatio > 0 { - other["user_group_ratio"] = userGroupRatio - } - return other -} diff --git a/relay/relay-text.go b/relay/relay-text.go index db8d0d3b..e0c8f047 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -541,5 +541,5 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["audio_input_price"] = audioInputPrice } model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 3c0cea42..b8004105 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -52,9 +52,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group) + groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) if hasUserGroupRatio { ratio = modelPrice * userGroupRatio } else { @@ -140,7 +140,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { other["user_group_ratio"] = userGroupRatio } model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, - modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other) + modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1edc9073..affae5fb 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -3,6 +3,7 @@ package service import ( "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "github.com/gin-gonic/gin" ) @@ -63,3 +64,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, info["cache_creation_ratio"] = cacheCreationRatio return info } + +func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { + other := make(map[string]interface{}) + other["model_price"] = priceData.ModelPrice + other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio + if priceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio + } + return other +} diff --git a/service/quota.go b/service/quota.go index 8005a1fb..c17616a7 100644 --- a/service/quota.go +++ b/service/quota.go @@ -95,18 +95,18 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag textOutTokens := usage.OutputTokenDetails.TextTokens audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens - groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group) + groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) modelRatio, _ := ratio_setting.GetModelRatio(modelName) autoGroup, exists := ctx.Get("auto_group") if exists { groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string)) log.Printf("final group ratio: %f", groupRatio) - relayInfo.Group = autoGroup.(string) + relayInfo.UsingGroup = autoGroup.(string) } actualGroupRatio := groupRatio - userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) + userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) if ok { actualGroupRatio = userGroupRatio } @@ -210,7 +210,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, @@ -287,7 +287,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { @@ -385,7 +385,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index f600a7b5..86f4a8d1 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -73,15 +73,15 @@ func GetGroupRatio(name string) float64 { return ratio } -func GetGroupGroupRatio(group, name string) (float64, bool) { +func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) { groupGroupRatioMutex.RLock() defer groupGroupRatioMutex.RUnlock() - gp, ok := GroupGroupRatio[group] + gp, ok := GroupGroupRatio[userGroup] if !ok { return -1, false } - ratio, ok := gp[name] + ratio, ok := gp[usingGroup] if !ok { return -1, false } diff --git a/web/vite.config.js b/web/vite.config.js index 5681d30e..78825b4a 100644 --- a/web/vite.config.js +++ b/web/vite.config.js @@ -66,6 +66,10 @@ export default defineConfig({ target: 'http://localhost:3000', changeOrigin: true, }, + '/mj': { + target: 'http://localhost:3000', + changeOrigin: true, + }, '/pg': { target: 'http://localhost:3000', changeOrigin: true,