fix(网关): 区分 Claude Code OAuth 适配

This commit is contained in:
cyhhao
2026-01-15 19:17:07 +08:00
parent 2a7d04fec4
commit b8c48fb477
3 changed files with 90 additions and 27 deletions

View File

@@ -707,6 +707,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)

View File

@@ -9,11 +9,15 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking

View File

@@ -65,6 +65,8 @@ var (
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -1941,6 +1943,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent)
}
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
@@ -2203,11 +2215,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
originalModel := reqModel
var toolNameMap map[string]string
if account.IsOAuth() {
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
}
@@ -2257,7 +2271,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
if err != nil {
@@ -2337,7 +2351,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -2369,7 +2383,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
@@ -2586,7 +2600,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -2599,7 +2613,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
@@ -2616,7 +2630,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -2632,7 +2646,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
@@ -2685,12 +2699,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
if tokenType == "oauth" && mimicClaudeCode {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -3070,7 +3084,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3163,7 +3177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
rewriteTools := account.IsOAuth()
rewriteTools := mimicClaudeCode
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for {
@@ -3327,6 +3341,37 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
}
}
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
if text == "" {
return text
}
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
submatches := toolNameFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
name := submatches[1]
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped == name {
return match
}
return strings.Replace(match, name, mapped, 1)
})
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
submatches := modelFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
model := submatches[1]
mapped := claude.DenormalizeModelID(model)
if mapped == model {
return match
}
return strings.Replace(match, model, mapped, 1)
})
return output
}
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
if !sseDataRe.MatchString(line) {
return line
@@ -3338,7 +3383,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
replaced := replaceToolNamesInText(data, toolNameMap)
if replaced == data {
return line
}
return "data: " + replaced
}
if !rewriteToolNamesInValue(event, toolNameMap) {
return line
@@ -3391,7 +3440,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3412,7 +3461,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
if account.IsOAuth() {
if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
@@ -3458,7 +3507,11 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
replaced := replaceToolNamesInText(string(body), toolNameMap)
if replaced == string(body) {
return body
}
return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
@@ -3635,7 +3688,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
if account.IsOAuth() {
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
@@ -3666,7 +3722,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
@@ -3699,7 +3755,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -3764,7 +3820,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
@@ -3779,7 +3835,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
@@ -3814,7 +3870,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
@@ -3828,13 +3884,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
if tokenType == "oauth" {
if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
if tokenType == "oauth" && mimicClaudeCode {
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {