diff --git a/dto/claude.go b/dto/claude.go index b5d43f23..1a7eacb1 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -159,6 +159,27 @@ type InputSchema struct { Required any `json:"required,omitempty"` } +type ClaudeWebSearchTool struct { + Type string `json:"type"` + Name string `json:"name"` + MaxUses int `json:"max_uses,omitempty"` + UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"` +} + +type ClaudeWebSearchUserLocation struct { + Type string `json:"type"` + Timezone string `json:"timezone,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` +} + +type ClaudeToolChoice struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"` +} + type ClaudeRequest struct { Model string `json:"model"` Prompt string `json:"prompt,omitempty"` @@ -177,6 +198,59 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` } +// AddTool 添加工具到请求中 +func (c *ClaudeRequest) AddTool(tool any) { + if c.Tools == nil { + c.Tools = make([]any, 0) + } + + switch tools := c.Tools.(type) { + case []any: + c.Tools = append(tools, tool) + default: + // 如果Tools不是[]any类型,重新初始化为[]any + c.Tools = []any{tool} + } +} + +// GetTools 获取工具列表 +func (c *ClaudeRequest) GetTools() []any { + if c.Tools == nil { + return nil + } + + switch tools := c.Tools.(type) { + case []any: + return tools + default: + return nil + } +} + +// ProcessTools 处理工具列表,支持类型断言 +func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) { + var normalTools []*Tool + var webSearchTools []*ClaudeWebSearchTool + + for _, tool := range tools { + switch t := tool.(type) { + case *Tool: + normalTools = append(normalTools, t) + case *ClaudeWebSearchTool: + webSearchTools = append(webSearchTools, t) + case Tool: + normalTools = append(normalTools, &t) + case ClaudeWebSearchTool: + webSearchTools = append(webSearchTools, &t) + default: + // 未知类型,跳过 + continue + } + } + + return normalTools, webSearchTools +} + type Thinking struct { Type string `json:"type"` BudgetTokens *int `json:"budget_tokens,omitempty"` @@ -251,8 +325,13 @@ func (c *ClaudeResponse) GetIndex() int { } type ClaudeUsage struct { - InputTokens int `json:"input_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` + ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"` +} + +type ClaudeServerToolUse struct { + WebSearchRequests int `json:"web_search_requests"` } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index d03c61c2..f20b573d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -18,6 +18,12 @@ import ( "github.com/gin-gonic/gin" ) +const ( + WebSearchMaxUsesLow = 1 + WebSearchMaxUsesMedium = 5 + WebSearchMaxUsesHigh = 10 +) + func stopReasonClaude2OpenAI(reason string) string { switch reason { case "stop_sequence": @@ -65,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla } func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { - claudeTools := make([]dto.Tool, 0, len(textRequest.Tools)) + claudeTools := make([]any, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { if params, ok := tool.Function.Parameters.(map[string]any); ok { @@ -85,10 +91,62 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla } claudeTool.InputSchema[s] = a } - claudeTools = append(claudeTools, claudeTool) + claudeTools = append(claudeTools, &claudeTool) } } + // Web search tool + // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool + if textRequest.WebSearchOptions != nil { + webSearchTool := dto.ClaudeWebSearchTool{ + Type: "web_search_20250305", + Name: "web_search", + } + + // 处理 user_location + if textRequest.WebSearchOptions.UserLocation != nil { + anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{ + Type: "approximate", // 固定为 "approximate" + } + + // 解析 UserLocation JSON + var userLocationMap map[string]interface{} + if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil { + // 检查是否有 approximate 字段 + if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok { + if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" { + anthropicUserLocation.Timezone = timezone + } + if country, ok := approximateData["country"].(string); ok && country != "" { + anthropicUserLocation.Country = country + } + if region, ok := approximateData["region"].(string); ok && region != "" { + anthropicUserLocation.Region = region + } + if city, ok := approximateData["city"].(string); ok && city != "" { + anthropicUserLocation.City = city + } + } + } + + webSearchTool.UserLocation = anthropicUserLocation + } + + // 处理 search_context_size 转换为 max_uses + if textRequest.WebSearchOptions.SearchContextSize != "" { + switch textRequest.WebSearchOptions.SearchContextSize { + case "low": + webSearchTool.MaxUses = WebSearchMaxUsesLow + case "medium": + webSearchTool.MaxUses = WebSearchMaxUsesMedium + case "high": + webSearchTool.MaxUses = WebSearchMaxUsesHigh + } + } + + claudeTools = append(claudeTools, &webSearchTool) + } + claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -100,6 +158,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla Tools: claudeTools, } + // 处理 tool_choice 和 parallel_tool_calls + if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil { + claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls) + if claudeToolChoice != nil { + claudeRequest.ToolChoice = claudeToolChoice + } + } + if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) } @@ -124,6 +190,27 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") } + if textRequest.ReasoningEffort != "" { + switch textRequest.ReasoningEffort { + case "low": + claudeRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: common.GetPointer[int](1280), + } + case "medium": + claudeRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: common.GetPointer[int](2048), + } + case "high": + claudeRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: common.GetPointer[int](4096), + } + } + } + + // 指定了 reasoning 参数,覆盖 budgetTokens if textRequest.Reasoning != nil { var reasoning openrouter.RequestReasoning if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil { @@ -645,6 +732,10 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud responseData = data } + if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 { + c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) + } + common.IOCopyBytesGracefully(c, nil, responseData) return nil } @@ -672,3 +763,51 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r } return nil, claudeInfo.Usage } + +func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice { + var claudeToolChoice *dto.ClaudeToolChoice + + // 处理 tool_choice 字符串值 + if toolChoiceStr, ok := toolChoice.(string); ok { + switch toolChoiceStr { + case "auto": + claudeToolChoice = &dto.ClaudeToolChoice{ + Type: "auto", + } + case "required": + claudeToolChoice = &dto.ClaudeToolChoice{ + Type: "any", + } + case "none": + claudeToolChoice = &dto.ClaudeToolChoice{ + Type: "none", + } + } + } else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok { + // 处理 tool_choice 对象值 + if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok { + if toolName, ok := function["name"].(string); ok { + claudeToolChoice = &dto.ClaudeToolChoice{ + Type: "tool", + Name: toolName, + } + } + } + } + + // 处理 parallel_tool_calls + if parallelToolCalls != nil { + if claudeToolChoice == nil { + // 如果没有 tool_choice,但有 parallel_tool_calls,创建默认的 auto 类型 + claudeToolChoice = &dto.ClaudeToolChoice{ + Type: "auto", + } + } + + // 设置 disable_parallel_tool_use + // 如果 parallel_tool_calls 为 true,则 disable_parallel_tool_use 为 false + claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls + } + + return claudeToolChoice +} diff --git a/relay/relay-text.go b/relay/relay-text.go index 46120529..9ca86a7c 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -379,6 +379,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // openai web search 工具计费 var dWebSearchQuota decimal.Decimal var webSearchPrice float64 + // response api 格式工具计费 if relayInfo.ResponsesUsageInfo != nil { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率) @@ -401,6 +402,17 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s", searchContextSize, dWebSearchQuota.String()) } + // claude web search tool 计费 + var dClaudeWebSearchQuota decimal.Decimal + var claudeWebSearchPrice float64 + claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests") + if claudeWebSearchCallCount > 0 { + claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand() + dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice). + Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount))) + extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", + claudeWebSearchCallCount, dClaudeWebSearchQuota.String()) + } // file search tool 计费 var dFileSearchQuota decimal.Decimal var fileSearchPrice float64 @@ -524,6 +536,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["web_search_call_count"] = 1 other["web_search_price"] = webSearchPrice } + } else if !dClaudeWebSearchQuota.IsZero() { + other["web_search"] = true + other["web_search_call_count"] = claudeWebSearchCallCount + other["web_search_price"] = claudeWebSearchPrice } if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil { if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists { diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index a401b923..a59090ce 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -23,6 +23,15 @@ const ( Gemini20FlashInputAudioPrice = 0.70 ) +const ( + // Claude Web search + ClaudeWebSearchPrice = 10.00 +) + +func GetClaudeWebSearchPricePerThousand() float64 { + return ClaudeWebSearchPrice +} + func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 { // 确定模型类型 // https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费