refactor: Improve channel testing and model price handling
This commit is contained in:
@@ -17,8 +17,8 @@ import (
|
|||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -73,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelMapping := *channel.ModelMapping
|
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[testModel] != "" {
|
|
||||||
testModel = modelMap[testModel]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
@@ -98,7 +86,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||||
|
|
||||||
meta := relaycommon.GenRelayInfo(c)
|
info := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
|
err = helper.ModelMappedHelper(c, info)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
@@ -106,12 +100,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
request := buildTestRequest(testModel)
|
||||||
meta.UpstreamModelName = testModel
|
info.OriginModelName = testModel
|
||||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
|
||||||
|
|
||||||
adaptor.Init(meta)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
|
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
@@ -121,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
c.Request.Body = io.NopCloser(requestBody)
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
@@ -133,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||||
}
|
}
|
||||||
@@ -146,27 +140,24 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
modelPrice, usePrice := setting.GetModelPrice(testModel, false)
|
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
|
||||||
modelRatio, success := setting.GetModelRatio(testModel)
|
if err != nil {
|
||||||
if !usePrice && !success {
|
return err, nil
|
||||||
return fmt.Errorf("模型 %s 倍率和价格均未设置,请设置或者开启自用模式", testModel), nil
|
|
||||||
}
|
}
|
||||||
completionRatio := setting.GetCompletionRatio(testModel)
|
|
||||||
ratio := modelRatio
|
|
||||||
quota := 0
|
quota := 0
|
||||||
if !usePrice {
|
if !priceData.UsePrice {
|
||||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
|
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||||
quota = int(math.Round(float64(quota) * ratio))
|
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||||
if ratio != 0 && quota <= 0 {
|
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||||
quota = 1
|
quota = 1
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
quota = int(modelPrice * common.QuotaPerUnit)
|
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
||||||
}
|
}
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
|
||||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
||||||
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
type PriceData struct {
|
type PriceData struct {
|
||||||
ModelPrice float64
|
ModelPrice float64
|
||||||
ModelRatio float64
|
ModelRatio float64
|
||||||
|
CompletionRatio float64
|
||||||
GroupRatio float64
|
GroupRatio float64
|
||||||
UsePrice bool
|
UsePrice bool
|
||||||
ShouldPreConsumedQuota int
|
ShouldPreConsumedQuota int
|
||||||
@@ -21,6 +22,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
groupRatio := setting.GetGroupRatio(info.Group)
|
groupRatio := setting.GetGroupRatio(info.Group)
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int
|
||||||
var modelRatio float64
|
var modelRatio float64
|
||||||
|
var completionRatio float64
|
||||||
if !usePrice {
|
if !usePrice {
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if maxTokens != 0 {
|
if maxTokens != 0 {
|
||||||
@@ -35,6 +37,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
|
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
completionRatio = setting.GetCompletionRatio(info.OriginModelName)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||||
} else {
|
} else {
|
||||||
@@ -43,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
return PriceData{
|
return PriceData{
|
||||||
ModelPrice: modelPrice,
|
ModelPrice: modelPrice,
|
||||||
ModelRatio: modelRatio,
|
ModelRatio: modelRatio,
|
||||||
|
CompletionRatio: completionRatio,
|
||||||
GroupRatio: groupRatio,
|
GroupRatio: groupRatio,
|
||||||
UsePrice: usePrice,
|
UsePrice: usePrice,
|
||||||
ShouldPreConsumedQuota: preConsumedQuota,
|
ShouldPreConsumedQuota: preConsumedQuota,
|
||||||
|
|||||||
Reference in New Issue
Block a user