✨ feat: enhance group ratio handling in pricing calculations
This commit is contained in:
@@ -2,7 +2,6 @@ package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -12,6 +11,11 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupRatioInfo carries the group-level pricing multipliers resolved for a
// single request by HandleGroupRatio.
type GroupRatioInfo struct {
	// GroupRatio is the multiplier applied when computing quota. It holds the
	// user-group special ratio when one is configured, otherwise the normal
	// ratio of the request's group.
	GroupRatio float64

	// GroupSpecialRatio is the ratio configured for the (user group, group)
	// pair. HandleGroupRatio leaves it at 1.0 when no such ratio exists.
	GroupSpecialRatio float64
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
@@ -19,32 +23,50 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
GroupRatio float64
|
||||
UserGroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: 1.0, // default user group ratio
|
||||
}
|
||||
|
||||
// check auto group
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("final group: %s", autoGroup))
|
||||
}
|
||||
relayInfo.Group = autoGroup.(string)
|
||||
}
|
||||
|
||||
// check user group special ratio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
// user group special ratio
|
||||
groupRatioInfo.GroupSpecialRatio = userGroupRatio
|
||||
groupRatioInfo.GroupRatio = userGroupRatio
|
||||
} else {
|
||||
// normal group ratio
|
||||
groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
|
||||
}
|
||||
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
var userGroupRatio float64
|
||||
autoGroup, exists := c.Get("auto_group")
|
||||
if exists {
|
||||
groupRatio = setting.GetGroupRatio(autoGroup.(string))
|
||||
log.Printf("final group ratio: %f", groupRatio)
|
||||
info.Group = autoGroup.(string)
|
||||
}
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
groupRatio = actualGroupRatio
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -74,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UserGroupRatio: userGroupRatio,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
|
||||
@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
userGroupRatio := priceData.UserGroupRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user