diff --git a/dto/openai_request.go b/dto/openai_request.go index bda1bb17..16cdf3a2 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,6 +2,7 @@ package dto import ( "encoding/json" + "one-api/common" "strings" ) @@ -57,6 +58,13 @@ type GeneralOpenAIRequest struct { WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` } +func (r *GeneralOpenAIRequest) ToMap() map[string]any { + result := make(map[string]any) + data, _ := common.EncodeJson(r) + _ = common.DecodeJson(data, &result) + return result +} + type ToolCallRequest struct { ID string `json:"id,omitempty"` Type string `json:"type"` @@ -74,11 +82,11 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r GeneralOpenAIRequest) GetMaxTokens() int { +func (r *GeneralOpenAIRequest) GetMaxTokens() int { return int(r.MaxTokens) } -func (r GeneralOpenAIRequest) ParseInput() []string { +func (r *GeneralOpenAIRequest) ParseInput() []string { if r.Input == nil { return nil } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 77afe2dd..2b8a52a2 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "strings" "github.com/gin-gonic/gin" ) @@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + if strings.HasSuffix(info.UpstreamModelName, "-search") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search") + request.Model = info.UpstreamModelName + toMap := request.ToMap() + toMap["web_search"] = map[string]any{ + "enable": true, + "enable_citation": true, + "enable_trace": true, + "enable_status": false, + } + return toMap, nil + } return request, nil }