fix(网关): 对齐 Claude OAuth 请求适配

This commit is contained in:
cyhhao
2026-01-15 18:54:42 +08:00
parent 04811c00cb
commit 2a7d04fec4
3 changed files with 481 additions and 26 deletions

View File

@@ -25,15 +25,15 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。 // DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", "User-Agent": "claude-cli/2.1.2 (external, cli)",
"X-Stainless-Lang": "js", "X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0", "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux", "X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64", "X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node", "X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0", "X-Stainless-Runtime-Version": "v24.3.0",
"X-Stainless-Retry-Count": "0", "X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60", "X-Stainless-Timeout": "600",
"X-App": "cli", "X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true", "Anthropic-Dangerous-Direct-Browser-Access": "true",
} }
@@ -79,3 +79,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型 // DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929" const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}

View File

@@ -18,12 +18,14 @@ import (
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -60,6 +62,36 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
"read": "Read",
"edit": "Edit",
"write": "Write",
"task": "Task",
"glob": "Glob",
"grep": "Grep",
"webfetch": "WebFetch",
"websearch": "WebSearch",
"todowrite": "TodoWrite",
"question": "AskUserQuestion",
}
openCodeToolOverrides = map[string]string{
"Bash": "bash",
"Read": "read",
"Edit": "edit",
"Write": "write",
"Task": "task",
"Glob": "glob",
"Grep": "grep",
"WebFetch": "webfetch",
"WebSearch": "websearch",
"TodoWrite": "todowrite",
"AskUserQuestion": "question",
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@@ -365,6 +397,268 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return newBody return newBody
} }
type claudeOAuthNormalizeOptions struct {
injectMetadata bool
metadataUserID string
stripSystemCacheControl bool
}
func stripToolPrefix(value string) string {
if value == "" {
return value
}
return toolPrefixRe.ReplaceAllString(value, "")
}
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string {
if value == "" {
return value
}
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
output = strings.Trim(output, "_")
return strings.ToLower(output)
}
func normalizeToolNameForClaude(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped
}
return mapped
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
if cache != nil {
if mapped, ok := cache[name]; ok {
return mapped
}
}
if mapped, ok := openCodeToolOverrides[name]; ok {
return mapped
}
return toSnakeCase(name)
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
return false
}
changed := false
for _, item := range blocks {
block, ok := item.(map[string]any)
if !ok {
continue
}
if _, exists := block["cache_control"]; !exists {
continue
}
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
continue
}
delete(block, "cache_control")
changed = true
}
return changed
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
if len(body) == 0 {
return body, modelID, nil
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
}
toolNameMap := make(map[string]string)
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
req["model"] = normalized
modelID = normalized
}
}
if rawTools, exists := req["tools"]; exists {
switch tools := rawTools.(type) {
case []any:
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
if name, ok := toolMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
}
}
tools[idx] = toolMap
}
req["tools"] = tools
case map[string]any:
normalizedTools := make(map[string]any, len(tools))
for name, value := range tools {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized == "" {
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
if toolName, ok := toolMap["name"].(string); ok {
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
if mappedName != "" && mappedName != toolName {
toolMap["name"] = mappedName
}
} else if normalized != name {
toolMap["name"] = normalized
}
normalizedTools[normalized] = toolMap
continue
}
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
}
} else {
req["tools"] = []any{}
}
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
continue
}
if name, ok := blockMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
}
}
}
}
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
}
}
if opts.injectMetadata && opts.metadataUserID != "" {
metadata, ok := req["metadata"].(map[string]any)
if !ok {
metadata = map[string]any{}
req["metadata"] = metadata
}
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID
}
}
if _, ok := req["temperature"]; ok {
delete(req, "temperature")
}
if _, ok := req["tool_choice"]; ok {
delete(req, "tool_choice")
}
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
return newBody, modelID, toolNameMap
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
if parsed == nil || fp == nil || fp.ClientID == "" {
return ""
}
if parsed.MetadataUserID != "" {
return ""
}
accountUUID := account.GetExtraString("account_uuid")
if accountUUID == "" {
return ""
}
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
}
func generateSessionUUID(seed string) string {
if seed == "" {
return uuid.NewString()
}
hash := sha256.Sum256([]byte(seed))
bytes := hash[:16]
bytes[6] = (bytes[6] & 0x0f) | 0x40
bytes[8] = (bytes[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
// SelectAccount 选择账号(粘性会话+优先级) // SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -1906,21 +2200,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
reqStream := parsed.Stream reqStream := parsed.Stream
originalModel := reqModel
var toolNameMap map[string]string
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) if account.IsOAuth() {
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
if account.IsOAuth() && // 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
!strings.Contains(strings.ToLower(reqModel), "haiku") && !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) { !systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System) body = injectClaudeCodePrompt(body, parsed.System)
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
normalizeOpts.metadataUserID = metadataUserID
}
}
}
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
} }
// 强制执行 cache_control 块数量限制(最多 4 个) // 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body) body = enforceCacheControlLimit(body)
// 应用模型映射仅对apikey类型账号 // 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
@@ -1948,10 +2257,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
// Capture upstream request body for ops retry of this attempt. // Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body)) c.Set(OpsUpstreamRequestBodyKey, string(body))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2029,7 +2337,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text. // also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil { if retryErr == nil {
@@ -2061,7 +2369,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
if buildErr2 == nil { if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency) retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil { if retryErr2 == nil {
@@ -2278,7 +2586,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
if reqStream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
if err != nil { if err != nil {
if err.Error() == "have error in stream" { if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
@@ -2291,7 +2599,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect clientDisconnect = streamResult.clientDisconnect
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2308,7 +2616,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
@@ -2377,6 +2685,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" { if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
} }
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// 处理anthropic-beta headerOAuth账号需要特殊处理 // 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" { if tokenType == "oauth" {
@@ -2459,6 +2770,26 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader return claude.APIKeyBetaHeader
} }
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
if req == nil {
return
}
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "application/json")
}
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
if req.Header.Get(key) == "" {
req.Header.Set(key, value)
}
}
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
func truncateForLog(b []byte, maxBytes int) string { func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 { if maxBytes <= 0 {
maxBytes = 2048 maxBytes = 2048
@@ -2739,7 +3070,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开 clientDisconnect bool // 客户端是否在流式传输过程中断开
} }
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2832,6 +3163,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
rewriteTools := account.IsOAuth()
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for { for {
@@ -2873,11 +3205,14 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// Extract data from SSE line (supports both "data: " and "data:" formats) // Extract data from SSE line (supports both "data: " and "data:" formats)
var data string var data string
if sseDataRe.MatchString(line) { if sseDataRe.MatchString(line) {
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段 // 如果有模型映射替换响应中的model字段
if needModelReplace { if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) line = s.replaceModelInSSELine(line, mappedModel, originalModel)
} }
if rewriteTools {
line = s.replaceToolNamesInSSELine(line, toolNameMap)
}
data = sseDataRe.ReplaceAllString(line, "")
} }
// 写入客户端(统一处理 data 行和非 data 行) // 写入客户端(统一处理 data 行和非 data 行)
@@ -2960,6 +3295,61 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData) return "data: " + string(newData)
} }
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
changed := false
if blockType, _ := v["type"].(string); blockType == "tool_use" {
if name, ok := v["name"].(string); ok {
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped != name {
v["name"] = mapped
changed = true
}
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
case []any:
changed := false
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
default:
return false
}
}
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
if !rewriteToolNamesInValue(event, toolNameMap) {
return line
}
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens标准Claude API格式 // 解析message_start获取input tokens标准Claude API格式
var msgStart struct { var msgStart struct {
@@ -3001,7 +3391,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} }
} }
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3022,6 +3412,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel { if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
if account.IsOAuth() {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -3059,6 +3452,24 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody return newBody
} }
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
if len(body) == 0 {
return body
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
}
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
@@ -3224,6 +3635,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body body := parsed.Body
reqModel := parsed.Model reqModel := parsed.Model
if account.IsOAuth() {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值 // Antigravity 账户不支持 count_tokens 转发,直接返回空值
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
@@ -3412,6 +3828,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" { if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
} }
if tokenType == "oauth" {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {

View File

@@ -24,13 +24,13 @@ var (
// 默认指纹值(当客户端未提供时使用) // 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{ var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)", UserAgent: "claude-cli/2.1.2 (external, cli)",
StainlessLang: "js", StainlessLang: "js",
StainlessPackageVersion: "0.52.0", StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux", StainlessOS: "Linux",
StainlessArch: "x64", StainlessArch: "x64",
StainlessRuntime: "node", StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0", StainlessRuntimeVersion: "v24.3.0",
} }
// Fingerprint represents account fingerprint data // Fingerprint represents account fingerprint data
@@ -230,7 +230,7 @@ func generateUUIDFromSeed(seed string) string {
} }
// parseUserAgentVersion 解析user-agent版本号 // parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.0.62 -> (2, 0, 62) // 例如claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式 // 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua) matches := userAgentVersionRegex.FindStringSubmatch(ua)