refactor: Simplify model mapping and pricing logic across relay modules

This commit is contained in:
1808837298@qq.com
2025-02-20 16:41:46 +08:00
parent 60aac77c08
commit 06da65a9d0
13 changed files with 279 additions and 199 deletions

View File

@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case relayconstant.RelayModeImagesGenerations: case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode) err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:

View File

@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value) common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value) common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota": case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)

View File

@@ -13,24 +13,24 @@ import (
) )
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
TokenId int TokenId int
TokenKey string TokenKey string
UserId int UserId int
Group string Group string
TokenUnlimited bool TokenUnlimited bool
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time FirstResponseTime time.Time
setFirstResponse bool setFirstResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool IsPlayground bool
UsePrice bool UsePrice bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
PromptTokens int PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
BaseUrl string BaseUrl string
SupportStreamOptions bool SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
IsModelMapped bool
ClientWs *websocket.Conn ClientWs *websocket.Conn
TargetWs *websocket.Conn TargetWs *websocket.Conn
InputAudioFormat string InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
ChannelSetting map[string]interface{} ChannelSetting map[string]interface{}
} }
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.ClientWs = ws info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second), FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"), OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"),
RecodeModelName: c.GetString("recode_model"), //RecodeModelName: c.GetString("original_model"),
ApiType: apiType, IsModelMapped: false,
ApiVersion: c.GetString("api_version"), ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiVersion: c.GetString("api_version"),
Organization: c.GetString("channel_organization"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
ChannelSetting: channelSetting, Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") { if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true info.IsPlayground = true
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi { if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region") info.ApiVersion = c.GetString("region")
} }
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || if streamSupportedChannels[info.ChannelType] {
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
info.SupportStreamOptions = true info.SupportStreamOptions = true
} }
return info return info

View File

@@ -0,0 +1,25 @@
package helper
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common"
)
// ModelMappedHelper applies the channel's model mapping to info. The mapping
// is read from the gin context key "model_mapping" as a JSON object of
// origin-model-name -> upstream-model-name. When an entry exists for
// info.OriginModelName, info.UpstreamModelName is rewritten and
// info.IsModelMapped is set so downstream logging can report both names.
// It returns an error only when the mapping JSON is malformed.
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
	modelMapping := c.GetString("model_mapping")
	// An unset or empty ("{}") mapping means no translation is configured.
	if modelMapping == "" || modelMapping == "{}" {
		return nil
	}
	modelMap := make(map[string]string)
	if err := json.Unmarshal([]byte(modelMapping), &modelMap); err != nil {
		// Wrap the JSON error so operators can see why the mapping failed,
		// instead of discarding the cause.
		return fmt.Errorf("unmarshal_model_mapping_failed: %w", err)
	}
	// Single comma-ok lookup instead of test-then-read; empty-string
	// targets are treated as "no mapping", matching previous behavior.
	if mapped, ok := modelMap[info.OriginModelName]; ok && mapped != "" {
		info.UpstreamModelName = mapped
		info.IsModelMapped = true
	}
	return nil
}

41
relay/helper/price.go Normal file
View File

@@ -0,0 +1,41 @@
package helper
import (
"github.com/gin-gonic/gin"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
)
// PriceData bundles the pricing parameters resolved for a single relay
// request, so quota pre-consumption and post-consumption share one source.
type PriceData struct {
	ModelPrice float64 // fixed per-call price; meaningful only when UsePrice is true
	ModelRatio float64 // per-token ratio; left zero when UsePrice is true
	GroupRatio float64 // multiplier derived from the user's group
	UsePrice bool // true when the model is billed per call rather than per token
	ShouldPreConsumedQuota int // quota to reserve before issuing the upstream call
}
// ModelPriceHelper resolves how a request for info.OriginModelName is billed:
// either a fixed per-call price or a per-token ratio, scaled by the user's
// group ratio. It also estimates the quota to reserve up front from
// promptTokens and maxTokens (maxTokens == 0 falls back to the global
// pre-consume default). The c parameter is currently unused by the body and
// is kept for signature consistency with the other relay helpers.
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
	price, billedPerCall := common.GetModelPrice(info.OriginModelName, false)
	group := setting.GetGroupRatio(info.Group)

	var (
		tokenRatio float64
		reserve    int
	)
	if billedPerCall {
		// Fixed-price model: the reservation is the full per-call cost.
		reserve = int(price * common.QuotaPerUnit * group)
	} else {
		// Token-priced model: reserve for the prompt plus the requested
		// completion budget, or the configured default when unbounded.
		estimated := common.PreConsumedQuota
		if maxTokens != 0 {
			estimated = promptTokens + maxTokens
		}
		tokenRatio = common.GetModelRatio(info.OriginModelName)
		reserve = int(float64(estimated) * (tokenRatio * group))
	}

	return PriceData{
		ModelPrice:             price,
		ModelRatio:             tokenRatio,
		GroupRatio:             group,
		UsePrice:               billedPerCall,
		ShouldPreConsumedQuota: reserve,
	}
}

View File

@@ -1,7 +1,6 @@
package relay package relay
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -11,6 +10,7 @@ import (
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
) )
@@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens relayInfo.PromptTokens = promptTokens
} }
modelRatio := common.GetModelRatio(audioRequest.Model) priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
}() }()
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping") if err != nil {
if modelMapping != "" { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
} }
relayInfo.UpstreamModelName = audioRequest.Model
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
@@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -12,6 +12,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return imageRequest, nil return imageRequest, nil
} }
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo) imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping") if err != nil {
if modelMapping != "" { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
}
} }
relayInfo.UpstreamModelName = imageRequest.Model
imageRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(imageRequest.Model, true) modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success { if !success {
@@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quality = "hd" quality = "hd"
} }
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) priceData := helper.PriceData{
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent) UsePrice: true,
GroupRatio: groupRatio,
ModelPrice: modelPrice,
ModelRatio: 0,
ShouldPreConsumedQuota: 0,
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil return nil
} }

View File

@@ -15,6 +15,7 @@ import (
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
} }
// map model name
//isModelMapped := false
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
//isModelMapped = true
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
relayInfo.RecodeModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
if setting.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo) err = checkRequestSensitive(textRequest, relayInfo)
if err != nil { if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
} }
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用 // 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens) c.Set("prompt_tokens", promptTokens)
} }
if !getModelPriceSuccess { priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr return openaiErr
} }
if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") { if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else { } else {
postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} }
return nil return nil
} }
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
} }
} }
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
modelPrice float64, usePrice bool, extraContent string) {
if usage == nil { if usage == nil {
usage = &dto.Usage{ usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens, PromptTokens: relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName) completionRatio := common.GetCompletionRatio(modelName)
ratio := priceData.ModelRatio * priceData.GroupRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quota := 0 quota := 0
if !usePrice { if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio)) quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {

View File

@@ -10,8 +10,8 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping") if err != nil {
//isModelMapped := false return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
} }
relayInfo.UpstreamModelName = embeddingRequest.Model embeddingRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getEmbeddingPromptToken(*embeddingRequest) promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -9,8 +9,8 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getRerankPromptToken(rerankRequest dto.RerankRequest) int { func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping") if err != nil {
//isModelMapped := false return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
} }
relayInfo.UpstreamModelName = rerankRequest.Model rerankRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getRerankPromptToken(*rerankRequest) promptToken := getRerankPromptToken(*rerankRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(rerankRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" { if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort other["reasoning_effort"] = relayInfo.ReasoningEffort
} }
if relayInfo.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
adminInfo := make(map[string]interface{}) adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo other["admin_info"] = adminInfo

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting" "one-api/setting"
"strings" "strings"
"time" "time"
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err return err
} }
modelName := relayInfo.UpstreamModelName modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName) completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{ quotaInfo := QuotaInfo{
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
} }
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName) completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName) audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName) audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{ quotaInfo := QuotaInfo{
InputDetails: TokenDetails{ InputDetails: TokenDetails{
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens, TextTokens: textOutTokens,
AudioTokens: audioOutTokens, AudioTokens: audioOutTokens,
}, },
ModelName: relayInfo.RecodeModelName, ModelName: relayInfo.OriginModelName,
UsePrice: usePrice, UsePrice: usePrice,
ModelRatio: modelRatio, ModelRatio: modelRatio,
GroupRatio: groupRatio, GroupRatio: groupRatio,
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0 quota = 0
logContent += fmt.Sprintf("(可能是上游超时)") logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota)) "tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else { } else {
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }
logModel := relayInfo.RecodeModelName logModel := relayInfo.OriginModelName
if extraContent != "" { if extraContent != "" {
logContent += ", " + extraContent logContent += ", " + extraContent
} }

View File

@@ -15,7 +15,7 @@ import {
Button, Descriptions, Button, Descriptions,
Form, Form,
Layout, Layout,
Modal, Modal, Popover,
Select, Select,
Space, Space,
Spin, Spin,
@@ -34,6 +34,7 @@ import {
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js'; import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js'; import { StyleContext } from '../context/Style/index.js';
import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
const { Header } = Layout; const { Header } = Layout;
@@ -141,7 +142,78 @@ const LogsTable = () => {
</Tag> </Tag>
); );
} }
} }
function renderModelName(record) {
let other = getLogOther(record.other);
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (!modelMapped) {
return <Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{' '}{record.model_name}{' '}
</Tag>;
} else {
return (
<>
<Space vertical align={'start'}>
<Popover content={
<div style={{padding: 10}}>
<Space vertical align={'start'}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{t('请求并计费模型')}{' '}{record.model_name}{' '}
</Tag>
<Tag
color={stringToColor(other.upstream_model_name)}
size='large'
onClick={(event) => {
copyText(event, other.upstream_model_name).then(r => {});
}}
>
{t('实际模型')}{' '}{other.upstream_model_name}{' '}
</Tag>
</Space>
</div>
}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
suffixIcon={<IconRefresh />}
>
{' '}{record.model_name}{' '}
</Tag>
</Popover>
{/*<Tooltip content={t('实际模型')}>*/}
{/* <Tag*/}
{/* color={stringToColor(other.upstream_model_name)}*/}
{/* size='large'*/}
{/* onClick={(event) => {*/}
{/* copyText(event, other.upstream_model_name).then(r => {});*/}
{/* }}*/}
{/* >*/}
{/* {' '}{other.upstream_model_name}{' '}*/}
{/* </Tag>*/}
{/*</Tooltip>*/}
</Space>
</>
);
}
}
const columns = [ const columns = [
{ {
@@ -272,18 +344,7 @@ const LogsTable = () => {
dataIndex: 'model_name', dataIndex: 'model_name',
render: (text, record, index) => { render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? ( return record.type === 0 || record.type === 2 ? (
<> <>{renderModelName(record)}</>
<Tag
color={stringToColor(text)}
size='large'
onClick={(event) => {
copyText(event, text);
}}
>
{' '}
{text}{' '}
</Tag>
</>
) : ( ) : (
<></> <></>
); );
@@ -580,6 +641,17 @@ const LogsTable = () => {
value: logs[i].content, value: logs[i].content,
}); });
if (logs[i].type === 2) { if (logs[i].type === 2) {
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (modelMapped) {
expandDataLocal.push({
key: t('请求并计费模型'),
value: logs[i].model_name,
});
expandDataLocal.push({
key: t('实际模型'),
value: other.upstream_model_name,
});
}
let content = ''; let content = '';
if (other?.ws || other?.audio) { if (other?.ws || other?.audio) {
content = renderAudioModelPrice( content = renderAudioModelPrice(