From 6aaa4aee6a5a6d431582ba2477a2a2151de24ba1 Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 7 Feb 2026 19:04:08 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=94=B6=E6=95=9B=20Claude=20Code=20?= =?UTF-8?q?=E6=8E=A2=E6=B5=8B=E6=8B=A6=E6=88=AA=E5=B9=B6=E8=A1=A5=E9=BD=90?= =?UTF-8?q?=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/handler/gateway_handler.go | 110 ++++++++++++++---- .../handler/gateway_handler_intercept_test.go | 65 +++++++++++ backend/internal/pkg/ctxkey/ctxkey.go | 4 + .../internal/service/claude_code_validator.go | 17 ++- .../service/claude_code_validator_test.go | 58 +++++++++ backend/internal/service/gateway_request.go | 48 ++++++++ .../internal/service/gateway_request_test.go | 14 +++ 7 files changed, 290 insertions(+), 26 deletions(-) create mode 100644 backend/internal/handler/gateway_handler_intercept_test.go create mode 100644 backend/internal/service/claude_code_validator_test.go diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7e6b2f03..ca4442e4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -111,9 +112,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) - setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body) @@ -121,11 +119,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) reqModel := parsedReq.Model reqStream := parsedReq.Stream + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 + // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 + if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { + ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + c.Request = c.Request.WithContext(ctx) + } + + // 检查是否为 Claude Code 客户端,设置到 context 中 + SetClaudeCodeClientContext(c, body) + isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) + + // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + setOpsRequestContext(c, reqModel, reqStream, body) // 验证 model 必填 @@ -241,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) + interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) if interceptType != InterceptTypeNone { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() @@ -403,7 +413,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) + interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) if interceptType != InterceptTypeNone { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() @@ -974,13 +984,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { type InterceptType int const ( - InterceptTypeNone InterceptType = iota - InterceptTypeWarmup // 预热请求(返回 "New Conversation") - InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) + InterceptTypeNone InterceptType = iota + InterceptTypeWarmup // 预热请求(返回 "New Conversation") + InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) + InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#") ) +// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感) +func isHaikuModel(model string) bool { + return strings.Contains(strings.ToLower(model), "haiku") +} + +// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求 +// 这类请求用于 Claude Code 验证 API 连通性 +// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求 +func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool { + return maxTokens == 1 && isHaikuModel(model) && !isStream +} + // detectInterceptType 检测请求是否需要拦截,返回拦截类型 -func detectInterceptType(body []byte) InterceptType { +// 参数说明: +// - body: 请求体字节 +// - model: 请求的模型名称 +// - maxTokens: max_tokens 值 +// - isStream: 是否为流式请求 +// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验 +func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType { + // 优先检查 max_tokens=1 + haiku 探测请求(仅非流式) + if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) { + return InterceptTypeMaxTokensOneHaiku + } + // 快速检查:如果不包含任何关键字,直接返回 bodyStr := string(body) hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:") @@ -1130,9 +1164,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce } } +// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式) +// 格式与 Claude API 真实响应一致,24 位随机字母数字 +func generateRealisticMsgID() string { + const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + const idLen = 24 + randomBytes := make([]byte, idLen) + if _, err := rand.Read(randomBytes); err != nil { + return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano()) + } + b := make([]byte, idLen) + for i := range b { + b[i] = charset[int(randomBytes[i])%len(charset)] + } + return "msg_bdrk_" + string(b) +} + // sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截) func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) { - var msgID, text string + var msgID, text, stopReason string var outputTokens int switch interceptType { @@ -1140,24 +1190,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter msgID = "msg_mock_suggestion" text = "" outputTokens = 1 + stopReason = "end_turn" + case InterceptTypeMaxTokensOneHaiku: + msgID = generateRealisticMsgID() + text = "#" + outputTokens = 1 + stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens default: // InterceptTypeWarmup msgID = "msg_mock_warmup" text = "New Conversation" outputTokens = 2 + stopReason = "end_turn" } - c.JSON(http.StatusOK, gin.H{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []gin.H{{"type": "text", "text": text}}, - "stop_reason": "end_turn", + // 构建完整的响应格式(与 Claude API 响应格式一致) + response := gin.H{ + "model": model, + "id": msgID, + "type": "message", + "role": "assistant", + "content": []gin.H{{"type": "text", "text": text}}, + "stop_reason": stopReason, + "stop_sequence": nil, "usage": gin.H{ - "input_tokens": 10, + "input_tokens": 10, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cache_creation": gin.H{ + "ephemeral_5m_input_tokens": 0, + "ephemeral_1h_input_tokens": 0, + }, "output_tokens": outputTokens, + "total_tokens": 10 + outputTokens, }, - }) + } + + c.JSON(http.StatusOK, response) } func billingErrorDetails(err error) (status int, code, message string) { diff --git a/backend/internal/handler/gateway_handler_intercept_test.go b/backend/internal/handler/gateway_handler_intercept_test.go new file mode 100644 index 00000000..9e7d77a1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_intercept_test.go @@ -0,0 +1,65 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestDetectInterceptType_MaxTokensOneHaikuRequiresClaudeCodeClient(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + + notClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, false) + require.Equal(t, InterceptTypeNone, notClaudeCode) + + isClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, true) + require.Equal(t, InterceptTypeMaxTokensOneHaiku, isClaudeCode) +} + +func TestDetectInterceptType_SuggestionModeUnaffected(t *testing.T) { + body := []byte(`{ + "messages":[{ + "role":"user", + "content":[{"type":"text","text":"[SUGGESTION MODE:foo]"}] + }], + "system":[] + }`) + + got := detectInterceptType(body, "claude-sonnet-4-5", 256, false, false) + require.Equal(t, InterceptTypeSuggestionMode, got) +} + +func TestSendMockInterceptResponse_MaxTokensOneHaiku(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + + sendMockInterceptResponse(ctx, "claude-haiku-4-5", InterceptTypeMaxTokensOneHaiku) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response)) + require.Equal(t, "max_tokens", response["stop_reason"]) + + id, ok := response["id"].(string) + require.True(t, ok) + require.True(t, strings.HasPrefix(id, "msg_bdrk_")) + + content, ok := response["content"].([]any) + require.True(t, ok) + require.NotEmpty(t, content) + + firstBlock, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "#", firstBlock["text"]) + + usage, ok := response["usage"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(1), usage["output_tokens"]) +} diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 6e173775..9bf563e7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -24,4 +24,8 @@ const ( ThinkingEnabled Key = "ctx_thinking_enabled" // Group 认证后的分组信息,由 API Key 认证中间件设置 Group Key = "ctx_group" + + // IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求 + // 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent) + IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku" ) diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index ab86f1e8..6d06c83e 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -56,7 +56,8 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator { // // Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x // Step 2: 对于非 messages 路径,只要 UA 匹配就通过 -// Step 3: 对于 messages 路径,进行严格验证: +// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证) +// Step 4: 对于 messages 路径,进行严格验证: // - System prompt 相似度检查 // - X-App header 检查 // - anthropic-beta header 检查 @@ -75,14 +76,20 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return true } - // Step 3: messages 路径,进行严格验证 + // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 + // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt + if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 + } - // 3.1 检查 system prompt 相似度 + // Step 4: messages 路径,进行严格验证 + + // 4.1 检查 system prompt 相似度 if !v.hasClaudeCodeSystemPrompt(body) { return false } - // 3.2 检查必需的 headers(值不为空即可) + // 4.2 检查必需的 headers(值不为空即可) xApp := r.Header.Get("X-App") if xApp == "" { return false @@ -98,7 +105,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return false } - // 3.3 验证 metadata.user_id + // 4.3 验证 metadata.user_id if body == nil { return false } diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go new file mode 100644 index 00000000..a4cd1886 --- /dev/null +++ b/backend/internal/service/claude_code_validator_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestClaudeCodeValidator_ProbeBypass(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)) + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.True(t, ok) +} + +func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "curl/8.0.0") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)) + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.False(t, ok) +} + +func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + + ok := validator.Validate(req, map[string]any{ + "model": "claude-haiku-4-5", + "max_tokens": 1, + }) + require.False(t, ok) +} + +func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) { + validator := NewClaudeCodeValidator() + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)") + + ok := validator.Validate(req, nil) + require.True(t, ok) +} diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 3d82ee2e..0ecd18aa 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "math" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) @@ -29,6 +30,7 @@ type ParsedRequest struct { Messages []any // messages 数组 HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + MaxTokens int // max_tokens 值(用于探测请求拦截) } // ParseGatewayRequest 解析网关请求体并返回结构化结果 @@ -79,9 +81,55 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { } } + // max_tokens + if rawMaxTokens, exists := req["max_tokens"]; exists { + if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok { + parsed.MaxTokens = maxTokens + } + } + return parsed, nil } +// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 +// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 +func parseIntegralNumber(raw any) (int, bool) { + switch v := raw.(type) { + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return 0, false + } + if v > float64(math.MaxInt) || v < float64(math.MinInt) { + return 0, false + } + return int(v), true + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + if v > int64(math.MaxInt) || v < int64(math.MinInt) { + return 0, false + } + return int(v), true + case json.Number: + i64, err := v.Int64() + if err != nil { + return 0, false + } + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return 0, false + } + return int(i64), true + default: + return 0, false + } +} + // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 03167618..4e390b0a 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -28,6 +28,20 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { require.True(t, parsed.ThinkingEnabled) } +func TestParseGatewayRequest_MaxTokens(t *testing.T) { + body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, 1, parsed.MaxTokens) +} + +func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { + body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, 0, parsed.MaxTokens) +} + func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) parsed, err := ParseGatewayRequest(body)