refactor: Improve channel testing and model price handling

This commit is contained in:
1808837298@qq.com
2025-03-02 15:47:12 +08:00
parent 816e831a2e
commit d042a1bd55
2 changed files with 27 additions and 32 deletions

View File

@@ -17,8 +17,8 @@ import (
"one-api/relay" "one-api/relay"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -73,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} }
} }
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
cache, err := model.GetUserCache(1) cache, err := model.GetUserCache(1)
if err != nil { if err != nil {
return err, nil return err, nil
@@ -98,7 +86,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
middleware.SetupContextForSelectedChannel(c, channel, testModel) middleware.SetupContextForSelectedChannel(c, channel, testModel)
meta := relaycommon.GenRelayInfo(c) info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info)
if err != nil {
return err, nil
}
apiType, _ := constant.ChannelType2APIType(channel.Type) apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
@@ -106,12 +100,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} }
request := buildTestRequest(testModel) request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel info.OriginModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta)) common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
adaptor.Init(meta) adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request) convertedRequest, err := adaptor.ConvertRequest(c, info, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -121,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} }
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody) c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -133,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
} }
} }
usageA, respErr := adaptor.DoResponse(c, httpResp, meta) usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr return fmt.Errorf("%s", respErr.Error.Message), respErr
} }
@@ -146,27 +140,24 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil { if err != nil {
return err, nil return err, nil
} }
modelPrice, usePrice := setting.GetModelPrice(testModel, false) priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
modelRatio, success := setting.GetModelRatio(testModel) if err != nil {
if !usePrice && !success { return err, nil
return fmt.Errorf("模型 %s 倍率和价格均未设置,请设置或者开启自用模式", testModel), nil
} }
completionRatio := setting.GetCompletionRatio(testModel)
ratio := modelRatio
quota := 0 quota := 0
if !usePrice { if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio)) quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
quota = int(math.Round(float64(quota) * ratio)) quota = int(math.Round(float64(quota) * priceData.ModelRatio))
if ratio != 0 && quota <= 0 { if priceData.ModelRatio != 0 && quota <= 0 {
quota = 1 quota = 1
} }
} else { } else {
quota = int(modelPrice * common.QuotaPerUnit) quota = int(priceData.ModelPrice * common.QuotaPerUnit)
} }
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice) other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other) quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))

View File

@@ -11,6 +11,7 @@ import (
type PriceData struct { type PriceData struct {
ModelPrice float64 ModelPrice float64
ModelRatio float64 ModelRatio float64
CompletionRatio float64
GroupRatio float64 GroupRatio float64
UsePrice bool UsePrice bool
ShouldPreConsumedQuota int ShouldPreConsumedQuota int
@@ -21,6 +22,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
groupRatio := setting.GetGroupRatio(info.Group) groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int var preConsumedQuota int
var modelRatio float64 var modelRatio float64
var completionRatio float64
if !usePrice { if !usePrice {
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 { if maxTokens != 0 {
@@ -35,6 +37,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
} }
} }
completionRatio = setting.GetCompletionRatio(info.OriginModelName)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio) preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else { } else {
@@ -43,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{ return PriceData{
ModelPrice: modelPrice, ModelPrice: modelPrice,
ModelRatio: modelRatio, ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio, GroupRatio: groupRatio,
UsePrice: usePrice, UsePrice: usePrice,
ShouldPreConsumedQuota: preConsumedQuota, ShouldPreConsumedQuota: preConsumedQuota,