feat: implement new handlers for audio, image, embedding, and responses processing

- Added new handlers: AudioHelper, ImageHelper, EmbeddingHelper, and ResponsesHelper to manage respective requests.
- Updated ModelMappedHelper to accept the request object so the mapped upstream model name is applied to the request in one place, removing the duplicated `request.Model = relayInfo.UpstreamModelName` assignments from each handler.
- Enhanced error handling and validation across new handlers to ensure robust request processing.
- Introduced new relay format constants (audio, image, embedding, rerank, responses, gemini) in relay_info, with dedicated GenRelayInfo* constructors that set the format on the RelayInfo, and updated relevant functions accordingly.
This commit is contained in:
CaIon
2025-06-20 16:02:23 +08:00
parent b087b20bac
commit d3286893c4
13 changed files with 95 additions and 38 deletions

View File

@@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c) info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info) err = helper.ModelMappedHelper(c, info, nil)
if err != nil { if err != nil {
return err, nil return err, nil
} }

View File

@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
} }
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo) audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil { if err != nil {
@@ -89,13 +89,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
}() }()
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

View File

@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式 // 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.OriginModelName, "-thinking-") { if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-") parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0] info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配 } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") { } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} }
} }

View File

@@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} }
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.OriginModelName modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")

View File

@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo.IsStream = true relayInfo.IsStream = true
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil { if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误 // count messages token error 计算promptTokens错误
if err != nil { if err != nil {

View File

@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
} }
const ( const (
RelayFormatOpenAI = "openai" RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude" RelayFormatClaude = "claude"
RelayFormatGemini = "gemini" RelayFormatGemini = "gemini"
RelayFormatOpenAIResponses = "openai_responses"
RelayFormatOpenAIAudio = "openai_audio"
RelayFormatOpenAIImage = "openai_image"
RelayFormatRerank = "rerank"
RelayFormatEmbedding = "embedding"
) )
type RerankerInfo struct { type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank info.RelayMode = relayconstant.RelayModeRerank
info.RelayFormat = RelayFormatRerank
info.RerankerInfo = &RerankerInfo{ info.RerankerInfo = &RerankerInfo{
Documents: req.Documents, Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(), ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info return info
} }
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIAudio
return info
}
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatEmbedding
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses info.RelayMode = relayconstant.RelayModeResponses
info.RelayFormat = RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{ info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo), BuiltInTools: make(map[string]*BuildInToolInfo),
} }
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
return info return info
} }
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatGemini
info.ShouldIncludeUsage = false
return info
}
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIImage
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] { if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true info.SupportStreamOptions = true
} }
// responses 模式不支持 StreamOptions
if relayconstant.RelayModeResponses == info.RelayMode {
info.SupportStreamOptions = false
}
return info return info
} }

View File

@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
} }
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest) err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest) promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken

View File

@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
} }
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式 // 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo) checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
// model mapped 模型映射 // model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
} }

View File

@@ -4,12 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
common2 "one-api/common"
"one-api/dto"
"one-api/relay/common" "one-api/relay/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" { if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel info.UpstreamModelName = currentModel
} }
} }
if request != nil {
switch info.RelayFormat {
case common.RelayFormatGemini:
// Gemini 模型映射
case common.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case common.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case common.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
default:
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
openAIRequest.Model = info.UpstreamModelName
} else {
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
}
}
}
return nil return nil
} }

View File

@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
} }
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo) imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil { if err != nil {
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
imageRequest.Model = relayInfo.UpstreamModelName
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)

View File

@@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用 // 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {

View File

@@ -42,13 +42,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest) promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken

View File

@@ -63,11 +63,11 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
} }
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
} }
req.Model = relayInfo.UpstreamModelName
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int) promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens) relayInfo.SetPromptTokens(promptTokens)