diff --git a/controller/channel-test.go b/controller/channel-test.go
index d162d8cf..26c97056 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	info := relaycommon.GenRelayInfo(c)
 
-	err = helper.ModelMappedHelper(c, info)
+	err = helper.ModelMappedHelper(c, info, nil)
 	if err != nil {
 		return err, nil
 	}
 
diff --git a/relay/relay-audio.go b/relay/audio_handler.go
similarity index 96%
rename from relay/relay-audio.go
rename to relay/audio_handler.go
index deb45c58..e55de042 100644
--- a/relay/relay-audio.go
+++ b/relay/audio_handler.go
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 }
 
 func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
 
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 	if err != nil {
@@ -89,13 +89,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}()
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	audioRequest.Model = relayInfo.UpstreamModelName
-
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index a81eb3a9..968d9c9b 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 		// 新增逻辑:处理 -thinking- 格式
-		if strings.Contains(info.OriginModelName, "-thinking-") {
+		if strings.Contains(info.UpstreamModelName, "-thinking-") {
 			parts := strings.Split(info.UpstreamModelName, "-thinking-")
 			info.UpstreamModelName = parts[0]
-		} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
+		} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
-		} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+		} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 		}
 	}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index d4b7c209..ef2c35be 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	}
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		modelName := info.OriginModelName
+		modelName := info.UpstreamModelName
 		isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
 			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
 			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index 567378fb..42139ddf 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
 		relayInfo.IsStream = true
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
 		return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	textRequest.Model = relayInfo.UpstreamModelName
-
 	promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
 	// count messages token error 计算promptTokens错误
 	if err != nil {
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index a842a58d..3759c363 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
 }
 
 const (
-	RelayFormatOpenAI = "openai"
-	RelayFormatClaude = "claude"
-	RelayFormatGemini = "gemini"
+	RelayFormatOpenAI          = "openai"
+	RelayFormatClaude          = "claude"
+	RelayFormatGemini          = "gemini"
+	RelayFormatOpenAIResponses = "openai_responses"
+	RelayFormatOpenAIAudio     = "openai_audio"
+	RelayFormatOpenAIImage     = "openai_image"
+	RelayFormatRerank          = "rerank"
+	RelayFormatEmbedding       = "embedding"
 )
 
 type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayMode = relayconstant.RelayModeRerank
+	info.RelayFormat = RelayFormatRerank
 	info.RerankerInfo = &RerankerInfo{
 		Documents:       req.Documents,
 		ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
 	return info
 }
 
+func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatOpenAIAudio
+	return info
+}
+
+func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatEmbedding
+	return info
+}
+
 func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayMode = relayconstant.RelayModeResponses
+	info.RelayFormat = RelayFormatOpenAIResponses
+
+	info.SupportStreamOptions = false
+
 	info.ResponsesUsageInfo = &ResponsesUsageInfo{
 		BuiltInTools: make(map[string]*BuildInToolInfo),
 	}
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
 	return info
 }
 
+func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatGemini
+	info.ShouldIncludeUsage = false
+	return info
+}
+
+func GenRelayInfoImage(c *gin.Context) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayFormat = RelayFormatOpenAIImage
+	return info
+}
+
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 
 	if streamSupportedChannels[info.ChannelType] {
 		info.SupportStreamOptions = true
 	}
-	// responses 模式不支持 StreamOptions
-	if relayconstant.RelayModeResponses == info.RelayMode {
-		info.SupportStreamOptions = false
-	}
 	return info
 }
diff --git a/relay/relay_embedding.go b/relay/embedding_handler.go
similarity index 96%
rename from relay/relay_embedding.go
rename to relay/embedding_handler.go
index b4909849..fbf4990a 100644
--- a/relay/relay_embedding.go
+++ b/relay/embedding_handler.go
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
 }
 
 func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoEmbedding(c)
 
 	var embeddingRequest *dto.EmbeddingRequest
 	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	embeddingRequest.Model = relayInfo.UpstreamModelName
-
 	promptToken := getEmbeddingPromptToken(*embeddingRequest)
 	relayInfo.PromptTokens = promptToken
 
diff --git a/relay/relay-gemini.go b/relay/gemini_handler.go
similarity index 98%
rename from relay/relay-gemini.go
rename to relay/gemini_handler.go
index 455b31b7..fa41cc7b 100644
--- a/relay/relay-gemini.go
+++ b/relay/gemini_handler.go
@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
 	}
 
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoGemini(c)
 
 	// 检查 Gemini 流式模式
 	checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	// model mapped 模型映射
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
 	}
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
index 9bf67c03..c1735149 100644
--- a/relay/helper/model_mapped.go
+++ b/relay/helper/model_mapped.go
@@ -4,12 +4,14 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	common2 "one-api/common"
+	"one-api/dto"
 	"one-api/relay/common"
 
 	"github.com/gin-gonic/gin"
 )
 
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
 			info.UpstreamModelName = currentModel
 		}
 	}
+	if request != nil {
+		switch info.RelayFormat {
+		case common.RelayFormatGemini:
+			// Gemini 模型映射
+		case common.RelayFormatClaude:
+			if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
+				claudeRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIResponses:
+			if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
+				openAIResponsesRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIAudio:
+			if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
+				openAIAudioRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatOpenAIImage:
+			if imageRequest, ok := request.(*dto.ImageRequest); ok {
+				imageRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatRerank:
+			if rerankRequest, ok := request.(*dto.RerankRequest); ok {
+				rerankRequest.Model = info.UpstreamModelName
+			}
+		case common.RelayFormatEmbedding:
+			if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
+				embeddingRequest.Model = info.UpstreamModelName
+			}
+		default:
+			if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
+				openAIRequest.Model = info.UpstreamModelName
+			} else {
+				common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
+			}
+		}
+	}
 	return nil
 }
diff --git a/relay/relay-image.go b/relay/image_handler.go
similarity index 98%
rename from relay/relay-image.go
rename to relay/image_handler.go
index 197a8af6..57917025 100644
--- a/relay/relay-image.go
+++ b/relay/image_handler.go
@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 }
 
 func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoImage(c)
 
 	imageRequest, err := getAndValidImageRequest(c, relayInfo)
 	if err != nil {
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	imageRequest.Model = relayInfo.UpstreamModelName
-
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 24fb8155..bf5a0259 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	textRequest.Model = relayInfo.UpstreamModelName
-
 	// 获取 promptTokens,如果上下文中已经存在,则直接使用
 	var promptTokens int
 	if value, exists := c.Get("prompt_tokens"); exists {
diff --git a/relay/relay_rerank.go b/relay/rerank_handler.go
similarity index 97%
rename from relay/relay_rerank.go
rename to relay/rerank_handler.go
index 6ca98de7..4d02c84f 100644
--- a/relay/relay_rerank.go
+++ b/relay/rerank_handler.go
@@ -42,13 +42,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 
-	rerankRequest.Model = relayInfo.UpstreamModelName
-
 	promptToken := getRerankPromptToken(*rerankRequest)
 	relayInfo.PromptTokens = promptToken
 
diff --git a/relay/relay-responses.go b/relay/responses_handler.go
similarity index 98%
rename from relay/relay-responses.go
rename to relay/responses_handler.go
index fd3ddb5a..8e8a3451 100644
--- a/relay/relay-responses.go
+++ b/relay/responses_handler.go
@@ -63,11 +63,11 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		}
 	}
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
 	}
-	req.Model = relayInfo.UpstreamModelName
+
 	if value, exists := c.Get("prompt_tokens"); exists {
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
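
Reviewer note: below is a minimal, illustrative sketch (not a file in this PR) of the handler-side pattern after the refactor, assuming the packages referenced in the diff (dto, relaycommon, relay/helper, common, service) and the "one-api" module path visible in the import strings. The parsed request is now passed into helper.ModelMappedHelper, which writes relayInfo.UpstreamModelName back into the request according to relayInfo.RelayFormat, so the per-handler "xxxRequest.Model = relayInfo.UpstreamModelName" assignments disappear; handlers with no request body to rewrite (the channel test) pass nil, and the RelayFormatGemini case is left as a no-op, presumably because the Gemini-native request carries the model in the URL rather than in the body. The handler name below is hypothetical.

// Illustrative sketch only; not part of the patch. Identifiers are taken from
// the diff above; the shape mirrors EmbeddingHelper.
package relay

import (
	"net/http"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
)

func exampleEmbeddingHandler(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
	// Format-specific constructor: tags the request as RelayFormatEmbedding so
	// ModelMappedHelper knows which concrete request type to update.
	relayInfo := relaycommon.GenRelayInfoEmbedding(c)

	var embeddingRequest dto.EmbeddingRequest
	if err := common.UnmarshalBodyReusable(c, &embeddingRequest); err != nil {
		return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
	}

	// The request is handed to the helper; when a mapping applies, the helper
	// itself sets embeddingRequest.Model = relayInfo.UpstreamModelName,
	// replacing the manual assignment each handler previously did here.
	if err := helper.ModelMappedHelper(c, relayInfo, &embeddingRequest); err != nil {
		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
	}

	// ... token counting, pricing, and the adaptor call continue as before.
	return nil
}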