fix(网关): 区分 Claude Code OAuth 适配
This commit is contained in:
@@ -707,6 +707,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
|
||||
@@ -9,11 +9,15 @@ const (
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
|
||||
@@ -65,6 +65,8 @@ var (
|
||||
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||
|
||||
claudeToolNameOverrides = map[string]string{
|
||||
"bash": "Bash",
|
||||
@@ -1941,6 +1943,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||
}
|
||||
|
||||
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return true
|
||||
}
|
||||
if parsed == nil || c == nil {
|
||||
return false
|
||||
}
|
||||
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
}
|
||||
|
||||
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
|
||||
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
@@ -2203,11 +2215,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
originalModel := reqModel
|
||||
var toolNameMap map[string]string
|
||||
|
||||
if account.IsOAuth() {
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||
if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
}
|
||||
@@ -2257,7 +2271,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
if err != nil {
|
||||
@@ -2337,7 +2351,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
@@ -2369,7 +2383,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr2 == nil {
|
||||
@@ -2586,7 +2600,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -2599,7 +2613,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2616,7 +2630,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -2632,7 +2646,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
var fingerprint *Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err != nil {
|
||||
@@ -2685,12 +2699,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if tokenType == "oauth" {
|
||||
if tokenType == "oauth" && mimicClaudeCode {
|
||||
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||
}
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
if tokenType == "oauth" && mimicClaudeCode {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
@@ -3070,7 +3084,7 @@ type streamingResult struct {
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -3163,7 +3177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
rewriteTools := account.IsOAuth()
|
||||
rewriteTools := mimicClaudeCode
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
for {
|
||||
@@ -3327,6 +3341,37 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
submatches := toolNameFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
name := submatches[1]
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped == name {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, name, mapped, 1)
|
||||
})
|
||||
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
|
||||
submatches := modelFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
model := submatches[1]
|
||||
mapped := claude.DenormalizeModelID(model)
|
||||
if mapped == model {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, model, mapped, 1)
|
||||
})
|
||||
return output
|
||||
}
|
||||
|
||||
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
|
||||
if !sseDataRe.MatchString(line) {
|
||||
return line
|
||||
@@ -3338,7 +3383,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
replaced := replaceToolNamesInText(data, toolNameMap)
|
||||
if replaced == data {
|
||||
return line
|
||||
}
|
||||
return "data: " + replaced
|
||||
}
|
||||
if !rewriteToolNamesInValue(event, toolNameMap) {
|
||||
return line
|
||||
@@ -3391,7 +3440,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -3412,7 +3461,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
if account.IsOAuth() {
|
||||
if mimicClaudeCode {
|
||||
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||
}
|
||||
|
||||
@@ -3458,7 +3507,11 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
replaced := replaceToolNamesInText(string(body), toolNameMap)
|
||||
if replaced == string(body) {
|
||||
return body
|
||||
}
|
||||
return []byte(replaced)
|
||||
}
|
||||
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||
return body
|
||||
@@ -3635,7 +3688,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
if account.IsOAuth() {
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
@@ -3666,7 +3722,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
@@ -3699,7 +3755,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
@@ -3764,7 +3820,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -3779,7 +3835,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil {
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
@@ -3814,7 +3870,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// OAuth 账号:应用指纹到请求头
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if fp != nil {
|
||||
s.identityService.ApplyFingerprint(req, fp)
|
||||
@@ -3828,13 +3884,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if tokenType == "oauth" {
|
||||
if tokenType == "oauth" && mimicClaudeCode {
|
||||
applyClaudeOAuthHeaderDefaults(req, false)
|
||||
}
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
if tokenType == "oauth" && mimicClaudeCode {
|
||||
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
|
||||
Reference in New Issue
Block a user