feat: enhance group ratio handling in pricing calculations

This commit is contained in:
CaIon
2025-06-17 21:05:35 +08:00
parent 3c276d13c4
commit 0f35d2368f
8 changed files with 76 additions and 106 deletions

View File

@@ -2,7 +2,6 @@ package helper
import (
"fmt"
"log"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
@@ -12,6 +11,11 @@ import (
"github.com/gin-gonic/gin"
)
type GroupRatioInfo struct {
GroupRatio float64
GroupSpecialRatio float64
}
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -19,32 +23,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
GroupRatio float64
UserGroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
groupRatioInfo := GroupRatioInfo{
GroupRatio: 1.0, // default ratio
GroupSpecialRatio: 1.0, // default user group ratio
}
// check auto group
autoGroup, exists := ctx.Get("auto_group")
if exists {
if common.DebugEnabled {
println(fmt.Sprintf("final group: %s", autoGroup))
}
relayInfo.Group = autoGroup.(string)
}
// check user group special ratio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
// user group special ratio
groupRatioInfo.GroupSpecialRatio = userGroupRatio
groupRatioInfo.GroupRatio = userGroupRatio
} else {
// normal group ratio
groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
}
return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var userGroupRatio float64
autoGroup, exists := c.Get("auto_group")
if exists {
groupRatio = setting.GetGroupRatio(autoGroup.(string))
log.Printf("final group ratio: %f", groupRatio)
info.Group = autoGroup.(string)
}
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
if ok {
actualGroupRatio = userGroupRatio
}
groupRatio = actualGroupRatio
groupRatioInfo := HandleGroupRatio(c, info)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -74,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UserGroupRatio: userGroupRatio,
GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,

View File

@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)

View File

@@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
userGroupRatio := priceData.UserGroupRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio

View File

@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
//if constant.ShouldCheckPromptSensitive() {
// err = checkRequestSensitive(textRequest, relayInfo)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
// }
//}
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
//// count messages token error 计算promptTokens错误
//if err != nil {
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
//}
//
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
relayInfo.UsePrice = true
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
userQuota, priceData, "")
return nil
}