feat: 修复重试后请求结构混乱,修复rerank端点无法使用
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -16,12 +17,17 @@ import (
|
||||
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
audioRequest, ok := info.Request.(*dto.AudioRequest)
|
||||
audioReq, ok := info.Request.(*dto.AudioRequest)
|
||||
if !ok {
|
||||
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, audioRequest)
|
||||
request, err := common.DeepCopy(audioReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -32,7 +38,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -21,13 +21,18 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.ClaudeRequest)
|
||||
claudeReq, ok := info.Request.(*dto.ClaudeRequest)
|
||||
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
request, err := common.DeepCopy(claudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -38,30 +43,30 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
if request.MaxTokens == 0 {
|
||||
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
if request.MaxTokens < 1280 {
|
||||
request.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
request.TopP = 0
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
info.UpstreamModelName = textRequest.Model
|
||||
request.Model = strings.TrimSuffix(request.Model, "-thinking")
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
@@ -72,7 +77,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -158,7 +158,14 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
if streamSupportedChannels[channelMeta.ChannelType] {
|
||||
channelMeta.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
info.ChannelMeta = channelMeta
|
||||
|
||||
// reset some fields based on channel meta
|
||||
// 重置某些字段,例如模型名称等
|
||||
if info.Request != nil {
|
||||
info.Request.SetModelName(info.OriginModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) ToString() string {
|
||||
|
||||
@@ -16,15 +16,19 @@ import (
|
||||
)
|
||||
|
||||
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, embeddingRequest)
|
||||
request, err := common.DeepCopy(embeddingReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -35,7 +39,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -53,13 +53,18 @@ func trimModelThinking(modelName string) string {
|
||||
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
request, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(geminiReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err := helper.ModelMappedHelper(c, info, request)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -170,7 +175,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||
info.IsGeminiBatchEmbedding = isBatch
|
||||
|
||||
var req any
|
||||
var req dto.Request
|
||||
var err error
|
||||
var inputTexts []string
|
||||
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/dto"
|
||||
common2 "one-api/logger"
|
||||
"one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -54,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case types.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case types.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
request.SetModelName(info.UpstreamModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,16 +20,19 @@ import (
|
||||
)
|
||||
|
||||
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
imageRequest, ok := info.Request.(*dto.ImageRequest)
|
||||
|
||||
imageReq, ok := info.Request.(*dto.ImageRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, imageRequest)
|
||||
request, err := common.DeepCopy(imageReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -49,7 +52,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
@@ -102,21 +105,21 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
|
||||
usage.(*dto.Usage).TotalTokens = int(request.N)
|
||||
}
|
||||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||
usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
|
||||
usage.(*dto.Usage).PromptTokens = int(request.N)
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
if imageRequest.Quality == "hd" {
|
||||
if request.Quality == "hd" {
|
||||
quality = "hd"
|
||||
}
|
||||
|
||||
var logContent string
|
||||
|
||||
if len(imageRequest.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
|
||||
@@ -25,38 +25,41 @@ import (
|
||||
)
|
||||
|
||||
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
|
||||
textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
if !ok {
|
||||
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
request, err := common.DeepCopy(textReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||
if request.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
includeUsage := true
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil {
|
||||
includeUsage = textRequest.StreamOptions.IncludeUsage
|
||||
if request.StreamOptions != nil {
|
||||
includeUsage = request.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !info.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
if !info.SupportStreamOptions || !request.Stream {
|
||||
request.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
request.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
@@ -81,7 +84,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -16,23 +16,20 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||
token += tkm
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
rerankRequest, ok := info.Request.(*dto.RerankRequest)
|
||||
rerankReq, ok := info.Request.(*dto.RerankRequest)
|
||||
if !ok {
|
||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
|
||||
}
|
||||
|
||||
err := helper.ModelMappedHelper(c, info, rerankRequest)
|
||||
request, err := common.DeepCopy(rerankReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -51,7 +48,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user