Merge pull request #1281 from QuantumNous/mj_usergroupratio

feat: support user-group-specific for mj and task
This commit is contained in:
Calcium-Ion
2025-06-22 18:08:11 +08:00
committed by GitHub
12 changed files with 113 additions and 86 deletions

View File

@@ -7,4 +7,5 @@ const (
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
ContextKeyUsingGroup = "group"
)

View File

@@ -171,7 +171,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}

View File

@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
}
userGroup = tokenGroup
}
c.Set("group", userGroup)
c.Set(constant.ContextKeyUsingGroup, userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {

View File

@@ -65,8 +65,8 @@ type RelayInfo struct {
TokenId int
TokenKey string
UserId int
Group string
UserGroup string
UsingGroup string // 使用的分组
UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
@@ -219,7 +219,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
tokenId := c.GetInt("token_id")
tokenKey := c.GetString("token_key")
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
// firstResponseTime = time.Now() - 1 second
@@ -239,7 +238,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenId: tokenId,
TokenKey: tokenKey,
UserId: userId,
Group: group,
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
UserGroup: c.GetString(constant.ContextKeyUserGroup),
TokenUnlimited: tokenUnlimited,
StartTime: startTime,

View File

@@ -13,6 +13,7 @@ import (
type GroupRatioInfo struct {
GroupRatio float64
GroupSpecialRatio float64
HasSpecialRatio bool
}
type PriceData struct {
@@ -31,7 +32,7 @@ func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
groupRatioInfo := GroupRatioInfo{
GroupRatio: 1.0, // default ratio
@@ -44,18 +45,19 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR
if common.DebugEnabled {
println(fmt.Sprintf("final group: %s", autoGroup))
}
relayInfo.Group = autoGroup.(string)
relayInfo.UsingGroup = autoGroup.(string)
}
// check user group special ratio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
if ok {
// user group special ratio
groupRatioInfo.GroupSpecialRatio = userGroupRatio
groupRatioInfo.GroupRatio = userGroupRatio
groupRatioInfo.HasSpecialRatio = true
} else {
// normal group ratio
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group)
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
}
return groupRatioInfo
@@ -120,6 +122,35 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return priceData, nil
}
type PerCallPriceData struct {
ModelPrice float64
Quota int
GroupRatioInfo GroupRatioInfo
}
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
groupRatioInfo := HandleGroupRatio(c, info)
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
priceData := PerCallPriceData{
ModelPrice: modelPrice,
Quota: quota,
GroupRatioInfo: groupRatioInfo,
}
return priceData
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {

View File

@@ -13,9 +13,9 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -174,18 +174,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
return &dto.MidjourneyResponse{
@@ -193,9 +184,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if userQuota-quota < 0 {
if userQuota-priceData.Quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
@@ -210,26 +200,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
//err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}
}()
midjResponse := &mjResp.Response
@@ -250,7 +232,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
Quota: priceData.Quota,
}
err = midjourneyTask.Insert()
if err != nil {
@@ -480,18 +462,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
return &dto.MidjourneyResponse{
@@ -499,9 +472,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if consumeQuota && userQuota-quota < 0 {
if consumeQuota && userQuota-priceData.Quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
@@ -516,22 +488,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
}
}()
@@ -559,7 +526,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
Quota: priceData.Quota,
}
if midjResponse.Code == 3 {
//无实例账号自动禁用渠道No available account instance

View File

@@ -541,5 +541,5 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["audio_input_price"] = audioInputPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -16,6 +15,8 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
/*
@@ -51,8 +52,14 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
var ratio float64
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
if hasUserGroupRatio {
ratio = modelPrice * userGroupRatio
} else {
ratio = modelPrice * groupRatio
}
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -121,12 +128,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
gRatio := groupRatio
if hasUserGroupRatio {
gRatio = userGroupRatio
}
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
if hasUserGroupRatio {
other["user_group_ratio"] = userGroupRatio
}
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -3,6 +3,7 @@ package service
import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"github.com/gin-gonic/gin"
)
@@ -63,3 +64,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
info["cache_creation_ratio"] = cacheCreationRatio
return info
}
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
other := make(map[string]interface{})
other["model_price"] = priceData.ModelPrice
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
if priceData.GroupRatioInfo.HasSpecialRatio {
other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
}
return other
}

View File

@@ -95,18 +95,18 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
autoGroup, exists := ctx.Get("auto_group")
if exists {
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
log.Printf("final group ratio: %f", groupRatio)
relayInfo.Group = autoGroup.(string)
relayInfo.UsingGroup = autoGroup.(string)
}
actualGroupRatio := groupRatio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
if ok {
actualGroupRatio = userGroupRatio
}
@@ -210,7 +210,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
@@ -287,7 +287,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
}
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
@@ -385,7 +385,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {

View File

@@ -73,15 +73,15 @@ func GetGroupRatio(name string) float64 {
return ratio
}
func GetGroupGroupRatio(group, name string) (float64, bool) {
func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
groupGroupRatioMutex.RLock()
defer groupGroupRatioMutex.RUnlock()
gp, ok := GroupGroupRatio[group]
gp, ok := GroupGroupRatio[userGroup]
if !ok {
return -1, false
}
ratio, ok := gp[name]
ratio, ok := gp[usingGroup]
if !ok {
return -1, false
}

View File

@@ -66,6 +66,10 @@ export default defineConfig({
target: 'http://localhost:3000',
changeOrigin: true,
},
'/mj': {
target: 'http://localhost:3000',
changeOrigin: true,
},
'/pg': {
target: 'http://localhost:3000',
changeOrigin: true,