refactor: Simplify model mapping and pricing logic across relay modules

This commit is contained in:
1808837298@qq.com
2025-02-20 16:41:46 +08:00
parent 60aac77c08
commit 06da65a9d0
13 changed files with 279 additions and 199 deletions

View File

@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case relayconstant.RelayModeImagesGenerations: case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode) err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:

View File

@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value) common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value) common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota": case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)

View File

@@ -30,7 +30,7 @@ type RelayInfo struct {
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
PromptTokens int PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
BaseUrl string BaseUrl string
SupportStreamOptions bool SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
IsModelMapped bool
ClientWs *websocket.Conn ClientWs *websocket.Conn
TargetWs *websocket.Conn TargetWs *websocket.Conn
InputAudioFormat string InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
ChannelSetting map[string]interface{} ChannelSetting map[string]interface{}
} }
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.ClientWs = ws info.ClientWs = ws
@@ -89,7 +102,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second), FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"), OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"),
RecodeModelName: c.GetString("recode_model"), //RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType, ApiType: apiType,
ApiVersion: c.GetString("api_version"), ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi { if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region") info.ApiVersion = c.GetString("region")
} }
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || if streamSupportedChannels[info.ChannelType] {
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
info.SupportStreamOptions = true info.SupportStreamOptions = true
} }
return info return info

View File

@@ -0,0 +1,25 @@
package helper
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return fmt.Errorf("unmarshal_model_mapping_failed")
}
if modelMap[info.OriginModelName] != "" {
info.UpstreamModelName = modelMap[info.OriginModelName]
info.IsModelMapped = true
}
}
return nil
}

41
relay/helper/price.go Normal file
View File

@@ -0,0 +1,41 @@
package helper
import (
"github.com/gin-gonic/gin"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
)
type PriceData struct {
ModelPrice float64
ModelRatio float64
GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
}
modelRatio = common.GetModelRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
return PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
ShouldPreConsumedQuota: preConsumedQuota,
}
}

View File

@@ -1,7 +1,6 @@
package relay package relay
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -11,6 +10,7 @@ import (
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
) )
@@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens relayInfo.PromptTokens = promptTokens
} }
modelRatio := common.GetModelRatio(audioRequest.Model) priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
}() }()
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model] audioRequest.Model = relayInfo.UpstreamModelName
}
}
relayInfo.UpstreamModelName = audioRequest.Model
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
@@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -12,6 +12,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return imageRequest, nil return imageRequest, nil
} }
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo) imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model] imageRequest.Model = relayInfo.UpstreamModelName
}
}
relayInfo.UpstreamModelName = imageRequest.Model
modelPrice, success := common.GetModelPrice(imageRequest.Model, true) modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success { if !success {
@@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quality = "hd" quality = "hd"
} }
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) priceData := helper.PriceData{
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent) UsePrice: true,
GroupRatio: groupRatio,
ModelPrice: modelPrice,
ModelRatio: 0,
ShouldPreConsumedQuota: 0,
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil return nil
} }

View File

@@ -15,6 +15,7 @@ import (
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
} }
// map model name
//isModelMapped := false
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
//isModelMapped = true
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
relayInfo.RecodeModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
if setting.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo) err = checkRequestSensitive(textRequest, relayInfo)
if err != nil { if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
} }
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用 // 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens) c.Set("prompt_tokens", promptTokens)
} }
if !getModelPriceSuccess { priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr return openaiErr
} }
if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") { if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else { } else {
postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} }
return nil return nil
} }
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
} }
} }
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
modelPrice float64, usePrice bool, extraContent string) {
if usage == nil { if usage == nil {
usage = &dto.Usage{ usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens, PromptTokens: relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName) completionRatio := common.GetCompletionRatio(modelName)
ratio := priceData.ModelRatio * priceData.GroupRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quota := 0 quota := 0
if !usePrice { if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio)) quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {

View File

@@ -10,8 +10,8 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
} }
relayInfo.UpstreamModelName = embeddingRequest.Model embeddingRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getEmbeddingPromptToken(*embeddingRequest) promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -9,8 +9,8 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getRerankPromptToken(rerankRequest dto.RerankRequest) int { func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
} }
// map model name err = helper.ModelMappedHelper(c, relayInfo)
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
} }
relayInfo.UpstreamModelName = rerankRequest.Model rerankRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptToken := getRerankPromptToken(*rerankRequest) promptToken := getRerankPromptToken(*rerankRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(rerankRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额 // pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil return nil
} }

View File

@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" { if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort other["reasoning_effort"] = relayInfo.ReasoningEffort
} }
if relayInfo.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
adminInfo := make(map[string]interface{}) adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo other["admin_info"] = adminInfo

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting" "one-api/setting"
"strings" "strings"
"time" "time"
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err return err
} }
modelName := relayInfo.UpstreamModelName modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName) completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{ quotaInfo := QuotaInfo{
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
} }
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName) completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName) audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName) audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{ quotaInfo := QuotaInfo{
InputDetails: TokenDetails{ InputDetails: TokenDetails{
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens, TextTokens: textOutTokens,
AudioTokens: audioOutTokens, AudioTokens: audioOutTokens,
}, },
ModelName: relayInfo.RecodeModelName, ModelName: relayInfo.OriginModelName,
UsePrice: usePrice, UsePrice: usePrice,
ModelRatio: modelRatio, ModelRatio: modelRatio,
GroupRatio: groupRatio, GroupRatio: groupRatio,
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0 quota = 0
logContent += fmt.Sprintf("(可能是上游超时)") logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota)) "tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else { } else {
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }
logModel := relayInfo.RecodeModelName logModel := relayInfo.OriginModelName
if extraContent != "" { if extraContent != "" {
logContent += ", " + extraContent logContent += ", " + extraContent
} }

View File

@@ -15,7 +15,7 @@ import {
Button, Descriptions, Button, Descriptions,
Form, Form,
Layout, Layout,
Modal, Modal, Popover,
Select, Select,
Space, Space,
Spin, Spin,
@@ -34,6 +34,7 @@ import {
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js'; import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js'; import { StyleContext } from '../context/Style/index.js';
import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
const { Header } = Layout; const { Header } = Layout;
@@ -143,6 +144,77 @@ const LogsTable = () => {
} }
} }
function renderModelName(record) {
let other = getLogOther(record.other);
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (!modelMapped) {
return <Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{' '}{record.model_name}{' '}
</Tag>;
} else {
return (
<>
<Space vertical align={'start'}>
<Popover content={
<div style={{padding: 10}}>
<Space vertical align={'start'}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
>
{t('请求并计费模型')}{' '}{record.model_name}{' '}
</Tag>
<Tag
color={stringToColor(other.upstream_model_name)}
size='large'
onClick={(event) => {
copyText(event, other.upstream_model_name).then(r => {});
}}
>
{t('实际模型')}{' '}{other.upstream_model_name}{' '}
</Tag>
</Space>
</div>
}>
<Tag
color={stringToColor(record.model_name)}
size='large'
onClick={(event) => {
copyText(event, record.model_name).then(r => {});
}}
suffixIcon={<IconRefresh />}
>
{' '}{record.model_name}{' '}
</Tag>
</Popover>
{/*<Tooltip content={t('实际模型')}>*/}
{/* <Tag*/}
{/* color={stringToColor(other.upstream_model_name)}*/}
{/* size='large'*/}
{/* onClick={(event) => {*/}
{/* copyText(event, other.upstream_model_name).then(r => {});*/}
{/* }}*/}
{/* >*/}
{/* {' '}{other.upstream_model_name}{' '}*/}
{/* </Tag>*/}
{/*</Tooltip>*/}
</Space>
</>
);
}
}
const columns = [ const columns = [
{ {
title: t('时间'), title: t('时间'),
@@ -272,18 +344,7 @@ const LogsTable = () => {
dataIndex: 'model_name', dataIndex: 'model_name',
render: (text, record, index) => { render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? ( return record.type === 0 || record.type === 2 ? (
<> <>{renderModelName(record)}</>
<Tag
color={stringToColor(text)}
size='large'
onClick={(event) => {
copyText(event, text);
}}
>
{' '}
{text}{' '}
</Tag>
</>
) : ( ) : (
<></> <></>
); );
@@ -580,6 +641,17 @@ const LogsTable = () => {
value: logs[i].content, value: logs[i].content,
}); });
if (logs[i].type === 2) { if (logs[i].type === 2) {
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
if (modelMapped) {
expandDataLocal.push({
key: t('请求并计费模型'),
value: logs[i].model_name,
});
expandDataLocal.push({
key: t('实际模型'),
value: other.upstream_model_name,
});
}
let content = ''; let content = '';
if (other?.ws || other?.audio) { if (other?.ws || other?.audio) {
content = renderAudioModelPrice( content = renderAudioModelPrice(