diff --git a/controller/relay.go b/controller/relay.go index d7e0f00a..0f739415 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: - err = relay.ImageHelper(c, relayMode) + err = relay.ImageHelper(c) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: diff --git a/model/option.go b/model/option.go index 0c4114a4..24935c69 100644 --- a/model/option.go +++ b/model/option.go @@ -84,7 +84,7 @@ func InitOptionMap() { common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() @@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "PreConsumedQuota": + case "ShouldPreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 007d17d6..1f4a3a42 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -13,24 +13,24 @@ import ( ) type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - setFirstResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string - RecodeModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + setFirstResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string + //RecodeModelName string RequestURLPath string ApiVersion string PromptTokens int @@ -39,6 +39,7 @@ type RelayInfo struct { BaseUrl string SupportStreamOptions bool ShouldIncludeUsage bool + IsModelMapped bool ClientWs *websocket.Conn TargetWs *websocket.Conn InputAudioFormat string @@ -50,6 +51,18 @@ type RelayInfo struct { ChannelSetting map[string]interface{} } +// 定义支持流式选项的通道类型 +var streamSupportedChannels = map[int]bool{ + common.ChannelTypeOpenAI: true, + common.ChannelTypeAnthropic: true, + common.ChannelTypeAws: true, + common.ChannelTypeGemini: true, + common.ChannelCloudflare: true, + common.ChannelTypeAzure: true, + common.ChannelTypeVolcEngine: true, + common.ChannelTypeOllama: true, +} + func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { info := GenRelayInfo(c) info.ClientWs = ws @@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { FirstResponseTime: startTime.Add(-time.Second), OriginModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"), - RecodeModelName: c.GetString("recode_model"), - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, + //RecodeModelName: c.GetString("original_model"), + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, } if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true @@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.ChannelType == common.ChannelTypeVertexAi { info.ApiVersion = c.GetString("region") } - if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || - info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || - info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure || - info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama { + if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } return info diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go new file mode 100644 index 00000000..948c5226 --- /dev/null +++ b/relay/helper/model_mapped.go @@ -0,0 +1,25 @@ +package helper + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "one-api/relay/common" +) + +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { + // map model name + modelMapping := c.GetString("model_mapping") + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return fmt.Errorf("unmarshal_model_mapping_failed") + } + if modelMap[info.OriginModelName] != "" { + info.UpstreamModelName = modelMap[info.OriginModelName] + info.IsModelMapped = true + } + } + return nil +} diff --git a/relay/helper/price.go b/relay/helper/price.go new file mode 100644 index 00000000..d65b86aa --- /dev/null +++ b/relay/helper/price.go @@ -0,0 +1,41 @@ +package helper + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + relaycommon "one-api/relay/common" + "one-api/setting" +) + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + GroupRatio float64 + UsePrice bool + ShouldPreConsumedQuota int +} + +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData { + modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false) + groupRatio := setting.GetGroupRatio(info.Group) + var preConsumedQuota int + var modelRatio float64 + if !usePrice { + preConsumedTokens := common.PreConsumedQuota + if maxTokens != 0 { + preConsumedTokens = promptTokens + maxTokens + } + modelRatio = common.GetModelRatio(info.OriginModelName) + ratio := modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + return PriceData{ + ModelPrice: modelPrice, + ModelRatio: modelRatio, + GroupRatio: groupRatio, + UsePrice: usePrice, + ShouldPreConsumedQuota: preConsumedQuota, + } +} diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 4c23a8f8..a858bb91 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -1,7 +1,6 @@ package relay import ( - "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" @@ -11,6 +10,7 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" ) @@ -73,15 +73,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo.PromptTokens = promptTokens } - modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } - preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -91,19 +89,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } }() - // map model name - modelMapping := c.GetString("model_mapping") - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = audioRequest.Model + + audioRequest.Model = relayInfo.UpstreamModelName adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { @@ -140,7 +131,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return openaiErr } - postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 207350da..24e62073 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -12,6 +12,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting" "strings" @@ -68,7 +69,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return imageRequest, nil } -func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { relayInfo := relaycommon.GenRelayInfo(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) @@ -77,19 +78,12 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageRequest.Model] != "" { - imageRequest.Model = modelMap[imageRequest.Model] - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = imageRequest.Model + + imageRequest.Model = relayInfo.UpstreamModelName modelPrice, success := common.GetModelPrice(imageRequest.Model, true) if !success { @@ -183,8 +177,15 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { quality = "hd" } - logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent) + priceData := helper.PriceData{ + UsePrice: true, + GroupRatio: groupRatio, + ModelPrice: modelPrice, + ModelRatio: 0, + ShouldPreConsumedQuota: 0, + } + logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) + postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent) return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 5216528e..b438571c 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -15,6 +15,7 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" "strings" @@ -76,33 +77,6 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } - // map model name - //isModelMapped := false - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - //isModelMapped = true - textRequest.Model = modelMap[textRequest.Model] - // set upstream model name - //isModelMapped = true - } - } - relayInfo.UpstreamModelName = textRequest.Model - relayInfo.RecodeModelName = textRequest.Model - modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - //err := service.SensitiveWordsCheck(textRequest) - if setting.ShouldCheckPromptSensitive() { err = checkRequestSensitive(textRequest, relayInfo) if err != nil { @@ -110,6 +84,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + } + + textRequest.Model = relayInfo.UpstreamModelName + // 获取 promptTokens,如果上下文中已经存在,则直接使用 var promptTokens int if value, exists := c.Get("prompt_tokens"); exists { @@ -124,20 +105,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - if !getModelPriceSuccess { - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + int(textRequest.MaxTokens) - } - modelRatio = common.GetModelRatio(textRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } + priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -220,10 +191,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return openaiErr } - if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") } else { - postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") } return nil } @@ -319,9 +290,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us } } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -333,12 +303,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens + modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") completionRatio := common.GetCompletionRatio(modelName) + ratio := priceData.ModelRatio * priceData.GroupRatio + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice quota := 0 - if !usePrice { + if !priceData.UsePrice { quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = int(math.Round(float64(quota) * ratio)) if ratio != 0 && quota <= 0 { diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go index 0a41c11d..18739d9f 100644 --- a/relay/relay_embedding.go +++ b/relay/relay_embedding.go @@ -10,8 +10,8 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" - "one-api/setting" ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { @@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[embeddingRequest.Model] != "" { - embeddingRequest.Model = modelMap[embeddingRequest.Model] - // set upstream model name - //isModelMapped = true - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = embeddingRequest.Model - modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 + embeddingRequest.Model = relayInfo.UpstreamModelName promptToken := getEmbeddingPromptToken(*embeddingRequest) - if !success { - preConsumedTokens := promptToken - modelRatio = common.GetModelRatio(embeddingRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } relayInfo.PromptTokens = promptToken + priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e53e37d4..37178cad 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -9,8 +9,8 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" - "one-api/setting" ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { @@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[rerankRequest.Model] != "" { - rerankRequest.Model = modelMap[rerankRequest.Model] - // set upstream model name - //isModelMapped = true - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = rerankRequest.Model - modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 + rerankRequest.Model = relayInfo.UpstreamModelName promptToken := getRerankPromptToken(*rerankRequest) - if !success { - preConsumedTokens := promptToken - modelRatio = common.GetModelRatio(rerankRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } relayInfo.PromptTokens = promptToken + priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1ce09d92..1e32d6f1 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m if relayInfo.ReasoningEffort != "" { other["reasoning_effort"] = relayInfo.ReasoningEffort } + if relayInfo.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = relayInfo.UpstreamModelName + } adminInfo := make(map[string]interface{}) adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") other["admin_info"] = adminInfo diff --git a/service/quota.go b/service/quota.go index 2ec04fe0..98b8530f 100644 --- a/service/quota.go +++ b/service/quota.go @@ -10,6 +10,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/setting" "strings" "time" @@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return err } - modelName := relayInfo.UpstreamModelName + modelName := relayInfo.OriginModelName textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens audioInputTokens := usage.InputTokenDetails.AudioTokens @@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod tokenName := ctx.GetString("token_name") completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName) quotaInfo := QuotaInfo{ @@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName) - audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName) + completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName) + audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName) + + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, TextTokens: textOutTokens, AudioTokens: audioOutTokens, }, - ModelName: relayInfo.RecodeModelName, + ModelName: relayInfo.OriginModelName, UsePrice: usePrice, ModelRatio: modelRatio, GroupRatio: groupRatio, @@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota)) + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) } else { quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { @@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - logModel := relayInfo.RecodeModelName + logModel := relayInfo.OriginModelName if extraContent != "" { logContent += ", " + extraContent } diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index e512a9e9..20758b72 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -15,7 +15,7 @@ import { Button, Descriptions, Form, Layout, - Modal, + Modal, Popover, Select, Space, Spin, @@ -34,6 +34,7 @@ import { import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; import { getLogOther } from '../helpers/other.js'; import { StyleContext } from '../context/Style/index.js'; +import { IconInherit, IconRefresh } from '@douyinfe/semi-icons'; const { Header } = Layout; @@ -141,7 +142,78 @@ const LogsTable = () => { ); } - } + } + + function renderModelName(record) { + + let other = getLogOther(record.other); + let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== ''; + if (!modelMapped) { + return { + copyText(event, record.model_name).then(r => {}); + }} + > + {' '}{record.model_name}{' '} + ; + } else { + return ( + <> + + + + { + copyText(event, record.model_name).then(r => {}); + }} + > + {t('请求并计费模型')}{' '}{record.model_name}{' '} + + { + copyText(event, other.upstream_model_name).then(r => {}); + }} + > + {t('实际模型')}{' '}{other.upstream_model_name}{' '} + + + + }> + { + copyText(event, record.model_name).then(r => {}); + }} + suffixIcon={} + > + {' '}{record.model_name}{' '} + + + {/**/} + {/* {*/} + {/* copyText(event, other.upstream_model_name).then(r => {});*/} + {/* }}*/} + {/* >*/} + {/* {' '}{other.upstream_model_name}{' '}*/} + {/* */} + {/**/} + + + ); + } + + } const columns = [ { @@ -272,18 +344,7 @@ const LogsTable = () => { dataIndex: 'model_name', render: (text, record, index) => { return record.type === 0 || record.type === 2 ? ( - <> - { - copyText(event, text); - }} - > - {' '} - {text}{' '} - - + <>{renderModelName(record)} ) : ( <> ); @@ -580,6 +641,17 @@ const LogsTable = () => { value: logs[i].content, }); if (logs[i].type === 2) { + let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== ''; + if (modelMapped) { + expandDataLocal.push({ + key: t('请求并计费模型'), + value: logs[i].model_name, + }); + expandDataLocal.push({ + key: t('实际模型'), + value: other.upstream_model_name, + }); + } let content = ''; if (other?.ws || other?.audio) { content = renderAudioModelPrice(