Merge pull request #1384 from QuantumNous/RequestOpenAI2ClaudeMessage

feat: 改进 RequestOpenAI2ClaudeMessage 和添加 claude web search 计费
This commit is contained in:
Calcium-Ion
2025-07-17 19:15:54 +08:00
committed by GitHub
4 changed files with 249 additions and 6 deletions

View File

@@ -159,6 +159,27 @@ type InputSchema struct {
Required any `json:"required,omitempty"`
}
type ClaudeWebSearchTool struct {
Type string `json:"type"`
Name string `json:"name"`
MaxUses int `json:"max_uses,omitempty"`
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
}
type ClaudeWebSearchUserLocation struct {
Type string `json:"type"`
Timezone string `json:"timezone,omitempty"`
Country string `json:"country,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
}
type ClaudeToolChoice struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
@@ -177,6 +198,59 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"`
}
// AddTool 添加工具到请求中
func (c *ClaudeRequest) AddTool(tool any) {
if c.Tools == nil {
c.Tools = make([]any, 0)
}
switch tools := c.Tools.(type) {
case []any:
c.Tools = append(tools, tool)
default:
// 如果Tools不是[]any类型重新初始化为[]any
c.Tools = []any{tool}
}
}
// GetTools 获取工具列表
func (c *ClaudeRequest) GetTools() []any {
if c.Tools == nil {
return nil
}
switch tools := c.Tools.(type) {
case []any:
return tools
default:
return nil
}
}
// ProcessTools 处理工具列表,支持类型断言
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
var normalTools []*Tool
var webSearchTools []*ClaudeWebSearchTool
for _, tool := range tools {
switch t := tool.(type) {
case *Tool:
normalTools = append(normalTools, t)
case *ClaudeWebSearchTool:
webSearchTools = append(webSearchTools, t)
case Tool:
normalTools = append(normalTools, &t)
case ClaudeWebSearchTool:
webSearchTools = append(webSearchTools, &t)
default:
// 未知类型,跳过
continue
}
}
return normalTools, webSearchTools
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens *int `json:"budget_tokens,omitempty"`
@@ -251,8 +325,13 @@ func (c *ClaudeResponse) GetIndex() int {
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
}
type ClaudeServerToolUse struct {
WebSearchRequests int `json:"web_search_requests"`
}

View File

@@ -18,6 +18,12 @@ import (
"github.com/gin-gonic/gin"
)
const (
WebSearchMaxUsesLow = 1
WebSearchMaxUsesMedium = 5
WebSearchMaxUsesHigh = 10
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
@@ -65,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
claudeTools := make([]any, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
@@ -85,10 +91,62 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
}
claudeTool.InputSchema[s] = a
}
claudeTools = append(claudeTools, claudeTool)
claudeTools = append(claudeTools, &claudeTool)
}
}
// Web search tool
// https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
if textRequest.WebSearchOptions != nil {
webSearchTool := dto.ClaudeWebSearchTool{
Type: "web_search_20250305",
Name: "web_search",
}
// 处理 user_location
if textRequest.WebSearchOptions.UserLocation != nil {
anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
Type: "approximate", // 固定为 "approximate"
}
// 解析 UserLocation JSON
var userLocationMap map[string]interface{}
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
// 检查是否有 approximate 字段
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
anthropicUserLocation.Timezone = timezone
}
if country, ok := approximateData["country"].(string); ok && country != "" {
anthropicUserLocation.Country = country
}
if region, ok := approximateData["region"].(string); ok && region != "" {
anthropicUserLocation.Region = region
}
if city, ok := approximateData["city"].(string); ok && city != "" {
anthropicUserLocation.City = city
}
}
}
webSearchTool.UserLocation = anthropicUserLocation
}
// 处理 search_context_size 转换为 max_uses
if textRequest.WebSearchOptions.SearchContextSize != "" {
switch textRequest.WebSearchOptions.SearchContextSize {
case "low":
webSearchTool.MaxUses = WebSearchMaxUsesLow
case "medium":
webSearchTool.MaxUses = WebSearchMaxUsesMedium
case "high":
webSearchTool.MaxUses = WebSearchMaxUsesHigh
}
}
claudeTools = append(claudeTools, &webSearchTool)
}
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
@@ -100,6 +158,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
Tools: claudeTools,
}
// 处理 tool_choice 和 parallel_tool_calls
if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
if claudeToolChoice != nil {
claudeRequest.ToolChoice = claudeToolChoice
}
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
@@ -124,6 +190,27 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if textRequest.ReasoningEffort != "" {
switch textRequest.ReasoningEffort {
case "low":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](1280),
}
case "medium":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](2048),
}
case "high":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](4096),
}
}
}
// 指定了 reasoning 参数,覆盖 budgetTokens
if textRequest.Reasoning != nil {
var reasoning openrouter.RequestReasoning
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
@@ -645,6 +732,10 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
responseData = data
}
if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
}
common.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
@@ -672,3 +763,51 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}
return nil, claudeInfo.Usage
}
func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
var claudeToolChoice *dto.ClaudeToolChoice
// 处理 tool_choice 字符串值
if toolChoiceStr, ok := toolChoice.(string); ok {
switch toolChoiceStr {
case "auto":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "auto",
}
case "required":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "any",
}
case "none":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "none",
}
}
} else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
// 处理 tool_choice 对象值
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
if toolName, ok := function["name"].(string); ok {
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "tool",
Name: toolName,
}
}
}
}
// 处理 parallel_tool_calls
if parallelToolCalls != nil {
if claudeToolChoice == nil {
// 如果没有 tool_choice但有 parallel_tool_calls创建默认的 auto 类型
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "auto",
}
}
// 设置 disable_parallel_tool_use
// 如果 parallel_tool_calls 为 true则 disable_parallel_tool_use 为 false
claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
}
return claudeToolChoice
}

View File

@@ -379,6 +379,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// openai web search 工具计费
var dWebSearchQuota decimal.Decimal
var webSearchPrice float64
// response api 格式工具计费
if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
@@ -401,6 +402,17 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s调用花费 %s",
searchContextSize, dWebSearchQuota.String())
}
// claude web search tool 计费
var dClaudeWebSearchQuota decimal.Decimal
var claudeWebSearchPrice float64
claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
if claudeWebSearchCallCount > 0 {
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
}
// file search tool 计费
var dFileSearchQuota decimal.Decimal
var fileSearchPrice float64
@@ -524,6 +536,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["web_search_call_count"] = 1
other["web_search_price"] = webSearchPrice
}
} else if !dClaudeWebSearchQuota.IsZero() {
other["web_search"] = true
other["web_search_call_count"] = claudeWebSearchCallCount
other["web_search_price"] = claudeWebSearchPrice
}
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {

View File

@@ -23,6 +23,15 @@ const (
Gemini20FlashInputAudioPrice = 0.70
)
const (
// Claude Web search
ClaudeWebSearchPrice = 10.00
)
func GetClaudeWebSearchPricePerThousand() float64 {
return ClaudeWebSearchPrice
}
func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 {
// 确定模型类型
// https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费