diff --git a/controller/channel-test.go b/controller/channel-test.go index 68e4d939..cf253900 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -17,8 +17,8 @@ import ( "one-api/relay" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" - "one-api/setting" "strconv" "strings" "sync" @@ -73,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } } - modelMapping := *channel.ModelMapping - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[testModel] != "" { - testModel = modelMap[testModel] - } - } - cache, err := model.GetUserCache(1) if err != nil { return err, nil @@ -98,7 +86,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr middleware.SetupContextForSelectedChannel(c, channel, testModel) - meta := relaycommon.GenRelayInfo(c) + info := relaycommon.GenRelayInfo(c) + + err = helper.ModelMappedHelper(c, info) + if err != nil { + return err, nil + } + apiType, _ := constant.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { @@ -106,12 +100,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } request := buildTestRequest(testModel) - meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta)) + info.OriginModelName = testModel + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info)) - adaptor.Init(meta) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRequest(c, meta, request) + convertedRequest, err := adaptor.ConvertRequest(c, info, request) if err != nil { return err, nil } @@ -121,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) - resp, err := adaptor.DoRequest(c, meta, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return err, nil } @@ -133,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err } } - usageA, respErr := adaptor.DoResponse(c, httpResp, meta) + usageA, respErr := adaptor.DoResponse(c, httpResp, info) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), respErr } @@ -146,27 +140,24 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if err != nil { return err, nil } - modelPrice, usePrice := setting.GetModelPrice(testModel, false) - modelRatio, success := setting.GetModelRatio(testModel) - if !usePrice && !success { - return fmt.Errorf("模型 %s 倍率和价格均未设置,请设置或者开启自用模式", testModel), nil + priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens)) + if err != nil { + return err, nil } - completionRatio := setting.GetCompletionRatio(testModel) - ratio := modelRatio quota := 0 - if !usePrice { - quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { + if !priceData.UsePrice { + quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) + quota = int(math.Round(float64(quota) * priceData.ModelRatio)) + if priceData.ModelRatio != 0 && quota <= 0 { quota = 1 } } else { - quota = int(modelPrice * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * common.QuotaPerUnit) } tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, "default", other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) diff --git a/relay/helper/price.go b/relay/helper/price.go index 97cbf162..51f64082 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -11,6 +11,7 @@ import ( type PriceData struct { ModelPrice float64 ModelRatio float64 + CompletionRatio float64 GroupRatio float64 UsePrice bool ShouldPreConsumedQuota int @@ -21,6 +22,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens groupRatio := setting.GetGroupRatio(info.Group) var preConsumedQuota int var modelRatio float64 + var completionRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota if maxTokens != 0 { @@ -35,6 +37,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) } } + completionRatio = setting.GetCompletionRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -43,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, + CompletionRatio: completionRatio, GroupRatio: groupRatio, UsePrice: usePrice, ShouldPreConsumedQuota: preConsumedQuota,