diff --git a/controller/relay.go b/controller/relay.go
index d7e0f00a..0f739415 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
- err = relay.ImageHelper(c, relayMode)
+ err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
diff --git a/model/option.go b/model/option.go
index 0c4114a4..24935c69 100644
--- a/model/option.go
+++ b/model/option.go
@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
- case "PreConsumedQuota":
+	case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 007d17d6..1f4a3a42 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -13,24 +13,24 @@ import (
)
type RelayInfo struct {
- ChannelType int
- ChannelId int
- TokenId int
- TokenKey string
- UserId int
- Group string
- TokenUnlimited bool
- StartTime time.Time
- FirstResponseTime time.Time
- setFirstResponse bool
- ApiType int
- IsStream bool
- IsPlayground bool
- UsePrice bool
- RelayMode int
- UpstreamModelName string
- OriginModelName string
- RecodeModelName string
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenKey string
+ UserId int
+ Group string
+ TokenUnlimited bool
+ StartTime time.Time
+ FirstResponseTime time.Time
+ setFirstResponse bool
+ ApiType int
+ IsStream bool
+ IsPlayground bool
+ UsePrice bool
+ RelayMode int
+ UpstreamModelName string
+ OriginModelName string
+ //RecodeModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
+ IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
ChannelSetting map[string]interface{}
}
+// streamSupportedChannels lists the channel types that support stream options.
+var streamSupportedChannels = map[int]bool{
+ common.ChannelTypeOpenAI: true,
+ common.ChannelTypeAnthropic: true,
+ common.ChannelTypeAws: true,
+ common.ChannelTypeGemini: true,
+ common.ChannelCloudflare: true,
+ common.ChannelTypeAzure: true,
+ common.ChannelTypeVolcEngine: true,
+ common.ChannelTypeOllama: true,
+}
+
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
- RecodeModelName: c.GetString("recode_model"),
- ApiType: apiType,
- ApiVersion: c.GetString("api_version"),
- ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
- Organization: c.GetString("channel_organization"),
- ChannelSetting: channelSetting,
+ //RecodeModelName: c.GetString("original_model"),
+ IsModelMapped: false,
+ ApiType: apiType,
+ ApiVersion: c.GetString("api_version"),
+ ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Organization: c.GetString("channel_organization"),
+ ChannelSetting: channelSetting,
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
- if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
- info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
- info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
- info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
+ if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
return info
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
new file mode 100644
index 00000000..948c5226
--- /dev/null
+++ b/relay/helper/model_mapped.go
@@ -0,0 +1,25 @@
+package helper
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/relay/common"
+)
+
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+ // map model name
+ modelMapping := c.GetString("model_mapping")
+ if modelMapping != "" && modelMapping != "{}" {
+ modelMap := make(map[string]string)
+ err := json.Unmarshal([]byte(modelMapping), &modelMap)
+ if err != nil {
+ return fmt.Errorf("unmarshal_model_mapping_failed")
+ }
+ if modelMap[info.OriginModelName] != "" {
+ info.UpstreamModelName = modelMap[info.OriginModelName]
+ info.IsModelMapped = true
+ }
+ }
+ return nil
+}
diff --git a/relay/helper/price.go b/relay/helper/price.go
new file mode 100644
index 00000000..d65b86aa
--- /dev/null
+++ b/relay/helper/price.go
@@ -0,0 +1,41 @@
+package helper
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ relaycommon "one-api/relay/common"
+ "one-api/setting"
+)
+
+type PriceData struct {
+ ModelPrice float64
+ ModelRatio float64
+ GroupRatio float64
+ UsePrice bool
+ ShouldPreConsumedQuota int
+}
+
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
+ modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
+ groupRatio := setting.GetGroupRatio(info.Group)
+ var preConsumedQuota int
+ var modelRatio float64
+ if !usePrice {
+ preConsumedTokens := common.PreConsumedQuota
+ if maxTokens != 0 {
+ preConsumedTokens = promptTokens + maxTokens
+ }
+ modelRatio = common.GetModelRatio(info.OriginModelName)
+ ratio := modelRatio * groupRatio
+ preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+ } else {
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ }
+ return PriceData{
+ ModelPrice: modelPrice,
+ ModelRatio: modelRatio,
+ GroupRatio: groupRatio,
+ UsePrice: usePrice,
+ ShouldPreConsumedQuota: preConsumedQuota,
+ }
+}
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index 4c23a8f8..a858bb91 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -1,7 +1,6 @@
package relay
import (
- "encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
@@ -11,6 +10,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
@@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens
}
- modelRatio := common.GetModelRatio(audioRequest.Model)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+ priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
+
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
- preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[audioRequest.Model] != "" {
- audioRequest.Model = modelMap[audioRequest.Model]
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = audioRequest.Model
+
+ audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
@@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
- postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/relay/relay-image.go b/relay/relay-image.go
index 207350da..24e62073 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -12,6 +12,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return imageRequest, nil
}
-func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[imageRequest.Model] != "" {
- imageRequest.Model = modelMap[imageRequest.Model]
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = imageRequest.Model
+
+ imageRequest.Model = relayInfo.UpstreamModelName
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
@@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quality = "hd"
}
- logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
- postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
+ priceData := helper.PriceData{
+ UsePrice: true,
+ GroupRatio: groupRatio,
+		ModelPrice:             imageRatio,
+ ModelRatio: 0,
+ ShouldPreConsumedQuota: 0,
+ }
+ logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+ postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 5216528e..b438571c 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -15,6 +15,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
- // map model name
- //isModelMapped := false
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[textRequest.Model] != "" {
- //isModelMapped = true
- textRequest.Model = modelMap[textRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
- }
- relayInfo.UpstreamModelName = textRequest.Model
- relayInfo.RecodeModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
+ }
+
+ textRequest.Model = relayInfo.UpstreamModelName
+
// 获取 promptTokens,如果上下文中已经存在,则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
- }
- modelRatio = common.GetModelRatio(textRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
- if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else {
- postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
}
return nil
}
@@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
}
}
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
+ modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
+ ratio := priceData.ModelRatio * priceData.GroupRatio
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
quota := 0
- if !usePrice {
+ if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go
index 0a41c11d..18739d9f 100644
--- a/relay/relay_embedding.go
+++ b/relay/relay_embedding.go
@@ -10,8 +10,8 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[embeddingRequest.Model] != "" {
- embeddingRequest.Model = modelMap[embeddingRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = embeddingRequest.Model
- modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
+ embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
- if !success {
- preConsumedTokens := promptToken
- modelRatio = common.GetModelRatio(embeddingRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
relayInfo.PromptTokens = promptToken
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go
index e53e37d4..37178cad 100644
--- a/relay/relay_rerank.go
+++ b/relay/relay_rerank.go
@@ -9,8 +9,8 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[rerankRequest.Model] != "" {
- rerankRequest.Model = modelMap[rerankRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = rerankRequest.Model
- modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
+ rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
- if !success {
- preConsumedTokens := promptToken
- modelRatio = common.GetModelRatio(rerankRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
relayInfo.PromptTokens = promptToken
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 1ce09d92..1e32d6f1 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
}
+ if relayInfo.IsModelMapped {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = relayInfo.UpstreamModelName
+ }
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
diff --git a/service/quota.go b/service/quota.go
index 2ec04fe0..98b8530f 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -10,6 +10,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/setting"
"strings"
"time"
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err
}
- modelName := relayInfo.UpstreamModelName
+ modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
- audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+ audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
- audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
- audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
+ completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
+ audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
+ audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
+
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens,
AudioTokens: audioOutTokens,
},
- ModelName: relayInfo.RecodeModelName,
+ ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- logModel := relayInfo.RecodeModelName
+ logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent
}
diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js
index e512a9e9..20758b72 100644
--- a/web/src/components/LogsTable.js
+++ b/web/src/components/LogsTable.js
@@ -15,7 +15,7 @@ import {
Button, Descriptions,
Form,
Layout,
- Modal,
+ Modal, Popover,
Select,
Space,
Spin,
@@ -34,6 +34,7 @@ import {
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js';
+import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
const { Header } = Layout;
@@ -141,7 +142,78 @@ const LogsTable = () => {
);
}
- }
+ }
+
+ function renderModelName(record) {
+
+ let other = getLogOther(record.other);
+ let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
+ if (!modelMapped) {
+ return {
+ copyText(event, record.model_name).then(r => {});
+ }}
+ >
+ {' '}{record.model_name}{' '}
+ ;
+ } else {
+ return (
+ <>
+
+
+
+ {
+ copyText(event, record.model_name).then(r => {});
+ }}
+ >
+ {t('请求并计费模型')}{' '}{record.model_name}{' '}
+
+ {
+ copyText(event, other.upstream_model_name).then(r => {});
+ }}
+ >
+ {t('实际模型')}{' '}{other.upstream_model_name}{' '}
+
+
+
+ }>
+ {
+ copyText(event, record.model_name).then(r => {});
+ }}
+ suffixIcon={}
+ >
+ {' '}{record.model_name}{' '}
+
+
+ {/**/}
+ {/* {*/}
+ {/* copyText(event, other.upstream_model_name).then(r => {});*/}
+ {/* }}*/}
+ {/* >*/}
+ {/* {' '}{other.upstream_model_name}{' '}*/}
+ {/* */}
+ {/**/}
+
+ >
+ );
+ }
+
+ }
const columns = [
{
@@ -272,18 +344,7 @@ const LogsTable = () => {
dataIndex: 'model_name',
render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? (
- <>
- {
- copyText(event, text);
- }}
- >
- {' '}
- {text}{' '}
-
- >
+ <>{renderModelName(record)}>
) : (
<>>
);
@@ -580,6 +641,17 @@ const LogsTable = () => {
value: logs[i].content,
});
if (logs[i].type === 2) {
+ let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
+ if (modelMapped) {
+ expandDataLocal.push({
+ key: t('请求并计费模型'),
+ value: logs[i].model_name,
+ });
+ expandDataLocal.push({
+ key: t('实际模型'),
+ value: other.upstream_model_name,
+ });
+ }
let content = '';
if (other?.ws || other?.audio) {
content = renderAudioModelPrice(