refactor: Simplify model mapping and pricing logic across relay modules

This commit is contained in:
1808837298@qq.com
2025-02-20 16:41:46 +08:00
parent 60aac77c08
commit 06da65a9d0
13 changed files with 279 additions and 199 deletions

View File

@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode)
err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:

View File

@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)

View File

@@ -13,24 +13,24 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
RecodeModelName string
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
//RecodeModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
ChannelSetting map[string]interface{}
}
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
RecodeModelName: c.GetString("recode_model"),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
return info

View File

@@ -0,0 +1,25 @@
package helper
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return fmt.Errorf("unmarshal_model_mapping_failed")
}
if modelMap[info.OriginModelName] != "" {
info.UpstreamModelName = modelMap[info.OriginModelName]
info.IsModelMapped = true
}
}
return nil
}

41
relay/helper/price.go Normal file
View File

@@ -0,0 +1,41 @@
package helper
import (
"github.com/gin-gonic/gin"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
)
type PriceData struct {
ModelPrice float64
ModelRatio float64
GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
}
modelRatio = common.GetModelRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
return PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
ShouldPreConsumedQuota: preConsumedQuota,
}
}

View File

@@ -1,7 +1,6 @@
package relay
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
@@ -11,6 +10,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
@@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = audioRequest.Model
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
@@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}

View File

@@ -12,6 +12,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return imageRequest, nil
}
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = imageRequest.Model
imageRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
@@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quality = "hd"
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
priceData := helper.PriceData{
UsePrice: true,
GroupRatio: groupRatio,
ModelPrice: modelPrice,
ModelRatio: 0,
ShouldPreConsumedQuota: 0,
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
// map model name
//isModelMapped := false
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
//isModelMapped = true
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
relayInfo.RecodeModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else {
postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
}
return nil
}
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
ratio := priceData.ModelRatio * priceData.GroupRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quota := 0
if !usePrice {
if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {

View File

@@ -10,8 +10,8 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = embeddingRequest.Model
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}

View File

@@ -9,8 +9,8 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(rerankRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}

View File

@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
}
if relayInfo.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"strings"
"time"
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err
}
modelName := relayInfo.UpstreamModelName
modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens,
AudioTokens: audioOutTokens,
},
ModelName: relayInfo.RecodeModelName,
ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := relayInfo.RecodeModelName
logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent
}

View File

@@ -15,7 +15,7 @@ import {
Button, Descriptions,
Form,
Layout,
Modal,
Modal, Popover,
Select,
Space,
Spin,
@@ -34,6 +34,7 @@ import {
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js';
import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
const { Header } = Layout;
@@ -141,7 +142,78 @@ const LogsTable = () => {
</Tag>
);
}
}
}
function renderModelName(record) {
let other = getLogOther(record.other);
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (!modelMapped) {
return <Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{' '}{record.model_name}{' '}
</Tag>;
} else {
return (
<>
<Space vertical align={'start'}>
<Popover content={
<div style={{padding: 10}}>
<Space vertical align={'start'}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{t('请求并计费模型')}{' '}{record.model_name}{' '}
</Tag>
<Tag
color={stringToColor(other.upstream_model_name)}
size='large'
onClick={(event) => {
copyText(event, other.upstream_model_name).then(r => {});
}}
>
{t('实际模型')}{' '}{other.upstream_model_name}{' '}
</Tag>
</Space>
</div>
}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
suffixIcon={<IconRefresh />}
>
{' '}{record.model_name}{' '}
</Tag>
</Popover>
{/*<Tooltip content={t('实际模型')}>*/}
{/* <Tag*/}
{/* color={stringToColor(other.upstream_model_name)}*/}
{/* size='large'*/}
{/* onClick={(event) => {*/}
{/* copyText(event, other.upstream_model_name).then(r => {});*/}
{/* }}*/}
{/* >*/}
{/* {' '}{other.upstream_model_name}{' '}*/}
{/* </Tag>*/}
{/*</Tooltip>*/}
</Space>
</>
);
}
}
const columns = [
{
@@ -272,18 +344,7 @@ const LogsTable = () => {
dataIndex: 'model_name',
render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? (
<>
<Tag
color={stringToColor(text)}
size='large'
onClick={(event) => {
copyText(event, text);
}}
>
{' '}
{text}{' '}
</Tag>
</>
<>{renderModelName(record)}</>
) : (
<></>
);
@@ -580,6 +641,17 @@ const LogsTable = () => {
value: logs[i].content,
});
if (logs[i].type === 2) {
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (modelMapped) {
expandDataLocal.push({
key: t('请求并计费模型'),
value: logs[i].model_name,
});
expandDataLocal.push({
key: t('实际模型'),
value: other.upstream_model_name,
});
}
let content = '';
if (other?.ws || other?.audio) {
content = renderAudioModelPrice(