fix: 收敛 Claude Code 探测拦截并补齐回归测试
This commit is contained in:
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -111,9 +112,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
@@ -121,11 +119,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 验证 model 必填
|
||||
@@ -241,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
@@ -403,7 +413,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
@@ -977,10 +987,34 @@ const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#")
|
||||
)
|
||||
|
||||
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
|
||||
func isHaikuModel(model string) bool {
|
||||
return strings.Contains(strings.ToLower(model), "haiku")
|
||||
}
|
||||
|
||||
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
|
||||
// 这类请求用于 Claude Code 验证 API 连通性
|
||||
// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求
|
||||
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
|
||||
return maxTokens == 1 && isHaikuModel(model) && !isStream
|
||||
}
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 参数说明:
|
||||
// - body: 请求体字节
|
||||
// - model: 请求的模型名称
|
||||
// - maxTokens: max_tokens 值
|
||||
// - isStream: 是否为流式请求
|
||||
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
|
||||
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
|
||||
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
|
||||
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
|
||||
return InterceptTypeMaxTokensOneHaiku
|
||||
}
|
||||
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
@@ -1130,9 +1164,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
}
|
||||
}
|
||||
|
||||
// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式)
|
||||
// 格式与 Claude API 真实响应一致,24 位随机字母数字
|
||||
func generateRealisticMsgID() string {
|
||||
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
const idLen = 24
|
||||
randomBytes := make([]byte, idLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
|
||||
}
|
||||
b := make([]byte, idLen)
|
||||
for i := range b {
|
||||
b[i] = charset[int(randomBytes[i])%len(charset)]
|
||||
}
|
||||
return "msg_bdrk_" + string(b)
|
||||
}
|
||||
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var msgID, text, stopReason string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
@@ -1140,24 +1190,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
stopReason = "end_turn"
|
||||
case InterceptTypeMaxTokensOneHaiku:
|
||||
msgID = generateRealisticMsgID()
|
||||
text = "#"
|
||||
outputTokens = 1
|
||||
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
stopReason = "end_turn"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
// 构建完整的响应格式(与 Claude API 响应格式一致)
|
||||
response := gin.H{
|
||||
"model": model,
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation": gin.H{
|
||||
"ephemeral_5m_input_tokens": 0,
|
||||
"ephemeral_1h_input_tokens": 0,
|
||||
},
|
||||
})
|
||||
"output_tokens": outputTokens,
|
||||
"total_tokens": 10 + outputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string) {
|
||||
|
||||
65
backend/internal/handler/gateway_handler_intercept_test.go
Normal file
65
backend/internal/handler/gateway_handler_intercept_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDetectInterceptType_MaxTokensOneHaikuRequiresClaudeCodeClient(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
|
||||
notClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, false)
|
||||
require.Equal(t, InterceptTypeNone, notClaudeCode)
|
||||
|
||||
isClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, true)
|
||||
require.Equal(t, InterceptTypeMaxTokensOneHaiku, isClaudeCode)
|
||||
}
|
||||
|
||||
func TestDetectInterceptType_SuggestionModeUnaffected(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[{
|
||||
"role":"user",
|
||||
"content":[{"type":"text","text":"[SUGGESTION MODE:foo]"}]
|
||||
}],
|
||||
"system":[]
|
||||
}`)
|
||||
|
||||
got := detectInterceptType(body, "claude-sonnet-4-5", 256, false, false)
|
||||
require.Equal(t, InterceptTypeSuggestionMode, got)
|
||||
}
|
||||
|
||||
func TestSendMockInterceptResponse_MaxTokensOneHaiku(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
|
||||
sendMockInterceptResponse(ctx, "claude-haiku-4-5", InterceptTypeMaxTokensOneHaiku)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response))
|
||||
require.Equal(t, "max_tokens", response["stop_reason"])
|
||||
|
||||
id, ok := response["id"].(string)
|
||||
require.True(t, ok)
|
||||
require.True(t, strings.HasPrefix(id, "msg_bdrk_"))
|
||||
|
||||
content, ok := response["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, content)
|
||||
|
||||
firstBlock, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "#", firstBlock["text"])
|
||||
|
||||
usage, ok := response["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(1), usage["output_tokens"])
|
||||
}
|
||||
@@ -24,4 +24,8 @@ const (
|
||||
ThinkingEnabled Key = "ctx_thinking_enabled"
|
||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||
Group Key = "ctx_group"
|
||||
|
||||
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
|
||||
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent)
|
||||
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
|
||||
)
|
||||
|
||||
@@ -56,7 +56,8 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 对于 messages 路径,进行严格验证:
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
||||
// Step 4: 对于 messages 路径,进行严格验证:
|
||||
// - System prompt 相似度检查
|
||||
// - X-App header 检查
|
||||
// - anthropic-beta header 检查
|
||||
@@ -75,14 +76,20 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
||||
return true
|
||||
}
|
||||
|
||||
// Step 3: messages 路径,进行严格验证
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
|
||||
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
|
||||
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
|
||||
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
|
||||
}
|
||||
|
||||
// 3.1 检查 system prompt 相似度
|
||||
// Step 4: messages 路径,进行严格验证
|
||||
|
||||
// 4.1 检查 system prompt 相似度
|
||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.2 检查必需的 headers(值不为空即可)
|
||||
// 4.2 检查必需的 headers(值不为空即可)
|
||||
xApp := r.Header.Get("X-App")
|
||||
if xApp == "" {
|
||||
return false
|
||||
@@ -98,7 +105,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.3 验证 metadata.user_id
|
||||
// 4.3 验证 metadata.user_id
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
58
backend/internal/service/claude_code_validator_test.go
Normal file
58
backend/internal/service/claude_code_validator_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClaudeCodeValidator_ProbeBypass(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "curl/8.0.0")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
|
||||
ok := validator.Validate(req, nil)
|
||||
require.True(t, ok)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
@@ -29,6 +30,7 @@ type ParsedRequest struct {
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
@@ -79,9 +81,55 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens
|
||||
if rawMaxTokens, exists := req["max_tokens"]; exists {
|
||||
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
|
||||
parsed.MaxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
|
||||
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
|
||||
func parseIntegralNumber(raw any) (int, bool) {
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) {
|
||||
return 0, false
|
||||
}
|
||||
if v > float64(math.MaxInt) || v < float64(math.MinInt) {
|
||||
return 0, false
|
||||
}
|
||||
return int(v), true
|
||||
case int:
|
||||
return v, true
|
||||
case int8:
|
||||
return int(v), true
|
||||
case int16:
|
||||
return int(v), true
|
||||
case int32:
|
||||
return int(v), true
|
||||
case int64:
|
||||
if v > int64(math.MaxInt) || v < int64(math.MinInt) {
|
||||
return 0, false
|
||||
}
|
||||
return int(v), true
|
||||
case json.Number:
|
||||
i64, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) {
|
||||
return 0, false
|
||||
}
|
||||
return int(i64), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// FilterThinkingBlocks removes thinking blocks from request body
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
|
||||
@@ -28,6 +28,20 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
||||
require.True(t, parsed.ThinkingEnabled)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, parsed.MaxTokens)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
|
||||
Reference in New Issue
Block a user