feat: 修复重试后请求结构混乱,修复rerank端点无法使用

This commit is contained in:
CaIon
2025-08-23 13:12:15 +08:00
parent e581422810
commit 4f23e53002
20 changed files with 273 additions and 106 deletions

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -16,12 +17,17 @@ import (
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
audioRequest, ok := info.Request.(*dto.AudioRequest)
audioReq, ok := info.Request.(*dto.AudioRequest)
if !ok {
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err := helper.ModelMappedHelper(c, info, audioRequest)
request, err := common.DeepCopy(audioReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -32,7 +38,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
}
adaptor.Init(info)
ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -21,13 +21,18 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
info.InitChannelMeta(c)
textRequest, ok := info.Request.(*dto.ClaudeRequest)
claudeReq, ok := info.Request.(*dto.ClaudeRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
}
err := helper.ModelMappedHelper(c, info, textRequest)
request, err := common.DeepCopy(claudeReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -38,30 +43,30 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
adaptor.Init(info)
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
if request.MaxTokens == 0 {
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
if textRequest.Thinking == nil {
strings.HasSuffix(request.Model, "-thinking") {
if request.Thinking == nil {
// 因为BudgetTokens 必须大于1024
if textRequest.MaxTokens < 1280 {
textRequest.MaxTokens = 1280
if request.MaxTokens < 1280 {
request.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
textRequest.Thinking = &dto.Thinking{
request.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
textRequest.TopP = 0
textRequest.Temperature = common.GetPointer[float64](1.0)
request.TopP = 0
request.Temperature = common.GetPointer[float64](1.0)
}
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
info.UpstreamModelName = textRequest.Model
request.Model = strings.TrimSuffix(request.Model, "-thinking")
info.UpstreamModelName = request.Model
}
var requestBody io.Reader
@@ -72,7 +77,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -158,7 +158,14 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
if streamSupportedChannels[channelMeta.ChannelType] {
channelMeta.SupportStreamOptions = true
}
info.ChannelMeta = channelMeta
// reset some fields based on channel meta
// 重置某些字段,例如模型名称等
if info.Request != nil {
info.Request.SetModelName(info.OriginModelName)
}
}
func (info *RelayInfo) ToString() string {

View File

@@ -16,15 +16,19 @@ import (
)
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
}
err := helper.ModelMappedHelper(c, info, embeddingRequest)
request, err := common.DeepCopy(embeddingReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -35,7 +39,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -53,13 +53,18 @@ func trimModelThinking(modelName string) string {
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
request, ok := info.Request.(*dto.GeminiChatRequest)
geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
}
request, err := common.DeepCopy(geminiReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
// model mapped 模型映射
err := helper.ModelMappedHelper(c, info, request)
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -170,7 +175,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
info.IsGeminiBatchEmbedding = isBatch
var req any
var req dto.Request
var err error
var inputTexts []string

View File

@@ -4,15 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"one-api/dto"
common2 "one-api/logger"
"one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
"one-api/dto"
"one-api/relay/common"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
@@ -54,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
}
}
if request != nil {
switch info.RelayFormat {
case types.RelayFormatGemini:
// Gemini 模型映射
case types.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case types.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case types.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case types.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case types.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case types.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
default:
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
openAIRequest.Model = info.UpstreamModelName
} else {
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
}
}
request.SetModelName(info.UpstreamModelName)
}
return nil
}

View File

@@ -20,16 +20,19 @@ import (
)
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
imageRequest, ok := info.Request.(*dto.ImageRequest)
imageReq, ok := info.Request.(*dto.ImageRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
}
err := helper.ModelMappedHelper(c, info, imageRequest)
request, err := common.DeepCopy(imageReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -49,7 +52,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
@@ -102,21 +105,21 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
}
if usage.(*dto.Usage).TotalTokens == 0 {
usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
usage.(*dto.Usage).TotalTokens = int(request.N)
}
if usage.(*dto.Usage).PromptTokens == 0 {
usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
usage.(*dto.Usage).PromptTokens = int(request.N)
}
quality := "standard"
if imageRequest.Quality == "hd" {
if request.Quality == "hd" {
quality = "hd"
}
var logContent string
if len(imageRequest.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
if len(request.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
}
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)

View File

@@ -25,38 +25,41 @@ import (
)
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
if !ok {
//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
}
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
request, err := common.DeepCopy(textReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err := helper.ModelMappedHelper(c, info, textRequest)
if request.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
includeUsage := true
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil {
includeUsage = textRequest.StreamOptions.IncludeUsage
if request.StreamOptions != nil {
includeUsage = request.StreamOptions.IncludeUsage
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !info.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
if !info.SupportStreamOptions || !request.Stream {
request.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
textRequest.StreamOptions = &dto.StreamOptions{
request.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
@@ -81,7 +84,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}

View File

@@ -16,23 +16,20 @@ import (
"github.com/gin-gonic/gin"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents {
tkm := service.CountTokenInput(document, rerankRequest.Model)
token += tkm
}
return token
}
func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
rerankRequest, ok := info.Request.(*dto.RerankRequest)
rerankReq, ok := info.Request.(*dto.RerankRequest)
if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
}
err := helper.ModelMappedHelper(c, info, rerankRequest)
request, err := common.DeepCopy(rerankReq)
if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
@@ -51,7 +48,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}