refactor: Simplify model mapping and pricing logic across relay modules

This commit is contained in:
1808837298@qq.com
2025-02-20 16:41:46 +08:00
parent 60aac77c08
commit 06da65a9d0
13 changed files with 279 additions and 199 deletions

View File

@@ -10,8 +10,8 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[embeddingRequest.Model] != "" {
embeddingRequest.Model = modelMap[embeddingRequest.Model]
// set upstream model name
//isModelMapped = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
relayInfo.UpstreamModelName = embeddingRequest.Model
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
if !success {
preConsumedTokens := promptToken
modelRatio = common.GetModelRatio(embeddingRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}