diff --git a/dto/openai_request.go b/dto/openai_request.go index e8833b3d..e491812a 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -53,6 +53,7 @@ type GeneralOpenAIRequest struct { Audio any `json:"audio,omitempty"` EnableThinking any `json:"enable_thinking,omitempty"` // ali ExtraBody any `json:"extra_body,omitempty"` + WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` } type ToolCallRequest struct { @@ -371,6 +372,11 @@ func (m *Message) ParseContent() []MediaContent { return contentList } +type WebSearchOptions struct { + SearchContextSize string `json:"search_context_size,omitempty"` + UserLocation json.RawMessage `json:"user_location,omitempty"` +} + type OpenAIResponsesRequest struct { Model string `json:"model"` Input json.RawMessage `json:"input,omitempty"` diff --git a/relay/relay-text.go b/relay/relay-text.go index 8d5cd384..f1105907 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -47,6 +47,20 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) if textRequest.Model == "" { return nil, errors.New("model is required") } + if textRequest.WebSearchOptions != nil { + if textRequest.WebSearchOptions.SearchContextSize != "" { + validSizes := map[string]bool{ + "high": true, + "medium": true, + "low": true, + } + if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { + return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") + } + } else { + textRequest.WebSearchOptions.SearchContextSize = "medium" + } + } switch relayInfo.RelayMode { case relayconstant.RelayModeCompletions: if textRequest.Prompt == "" { @@ -76,6 +90,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateTextRequest(c, relayInfo) + if textRequest.WebSearchOptions != nil { + c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) + } + if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) @@ -370,9 +388,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s", + extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s", webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) } + } else if strings.HasSuffix(modelName, "search-preview") { + // search-preview 模型不支持 response api + searchContextSize := ctx.GetString("chat_completion_web_search_context_size") + if searchContextSize == "" { + searchContextSize = "medium" + } + webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize) + dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). + Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s", + searchContextSize, dWebSearchQuota.String()) } // file search tool 计费 var dFileSearchQuota decimal.Decimal @@ -463,10 +492,16 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["image_ratio"] = imageRatio other["image_output"] = imageTokens } - if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil { - if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { + if !dWebSearchQuota.IsZero() { + if relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { + other["web_search"] = true + other["web_search_call_count"] = webSearchTool.CallCount + other["web_search_price"] = webSearchPrice + } + } else if strings.HasSuffix(modelName, "search-preview") { other["web_search"] = true - other["web_search_call_count"] = webSearchTool.CallCount + other["web_search_call_count"] = 1 other["web_search_price"] = webSearchPrice } }