fix: 收敛 Claude Code 探测拦截并补齐回归测试
This commit is contained in:
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -111,9 +112,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
|
||||||
SetClaudeCodeClientContext(c, body)
|
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
@@ -121,11 +119,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
|
||||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
|
||||||
reqModel := parsedReq.Model
|
reqModel := parsedReq.Model
|
||||||
reqStream := parsedReq.Stream
|
reqStream := parsedReq.Stream
|
||||||
|
|
||||||
|
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||||
|
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||||
|
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||||
|
|
||||||
|
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||||
|
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
@@ -241,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
interceptType := detectInterceptType(body)
|
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||||
if interceptType != InterceptTypeNone {
|
if interceptType != InterceptTypeNone {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
selection.ReleaseFunc()
|
selection.ReleaseFunc()
|
||||||
@@ -403,7 +413,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
interceptType := detectInterceptType(body)
|
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||||
if interceptType != InterceptTypeNone {
|
if interceptType != InterceptTypeNone {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
selection.ReleaseFunc()
|
selection.ReleaseFunc()
|
||||||
@@ -974,13 +984,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
type InterceptType int
|
type InterceptType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
InterceptTypeNone InterceptType = iota
|
InterceptTypeNone InterceptType = iota
|
||||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||||
|
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
|
||||||
|
func isHaikuModel(model string) bool {
|
||||||
|
return strings.Contains(strings.ToLower(model), "haiku")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
|
||||||
|
// 这类请求用于 Claude Code 验证 API 连通性
|
||||||
|
// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求
|
||||||
|
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
|
||||||
|
return maxTokens == 1 && isHaikuModel(model) && !isStream
|
||||||
|
}
|
||||||
|
|
||||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||||
func detectInterceptType(body []byte) InterceptType {
|
// 参数说明:
|
||||||
|
// - body: 请求体字节
|
||||||
|
// - model: 请求的模型名称
|
||||||
|
// - maxTokens: max_tokens 值
|
||||||
|
// - isStream: 是否为流式请求
|
||||||
|
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
|
||||||
|
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
|
||||||
|
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
|
||||||
|
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
|
||||||
|
return InterceptTypeMaxTokensOneHaiku
|
||||||
|
}
|
||||||
|
|
||||||
// 快速检查:如果不包含任何关键字,直接返回
|
// 快速检查:如果不包含任何关键字,直接返回
|
||||||
bodyStr := string(body)
|
bodyStr := string(body)
|
||||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||||
@@ -1130,9 +1164,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式)
|
||||||
|
// 格式与 Claude API 真实响应一致,24 位随机字母数字
|
||||||
|
func generateRealisticMsgID() string {
|
||||||
|
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
|
const idLen = 24
|
||||||
|
randomBytes := make([]byte, idLen)
|
||||||
|
if _, err := rand.Read(randomBytes); err != nil {
|
||||||
|
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
b := make([]byte, idLen)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[int(randomBytes[i])%len(charset)]
|
||||||
|
}
|
||||||
|
return "msg_bdrk_" + string(b)
|
||||||
|
}
|
||||||
|
|
||||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||||
var msgID, text string
|
var msgID, text, stopReason string
|
||||||
var outputTokens int
|
var outputTokens int
|
||||||
|
|
||||||
switch interceptType {
|
switch interceptType {
|
||||||
@@ -1140,24 +1190,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
|||||||
msgID = "msg_mock_suggestion"
|
msgID = "msg_mock_suggestion"
|
||||||
text = ""
|
text = ""
|
||||||
outputTokens = 1
|
outputTokens = 1
|
||||||
|
stopReason = "end_turn"
|
||||||
|
case InterceptTypeMaxTokensOneHaiku:
|
||||||
|
msgID = generateRealisticMsgID()
|
||||||
|
text = "#"
|
||||||
|
outputTokens = 1
|
||||||
|
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
|
||||||
default: // InterceptTypeWarmup
|
default: // InterceptTypeWarmup
|
||||||
msgID = "msg_mock_warmup"
|
msgID = "msg_mock_warmup"
|
||||||
text = "New Conversation"
|
text = "New Conversation"
|
||||||
outputTokens = 2
|
outputTokens = 2
|
||||||
|
stopReason = "end_turn"
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
// 构建完整的响应格式(与 Claude API 响应格式一致)
|
||||||
"id": msgID,
|
response := gin.H{
|
||||||
"type": "message",
|
"model": model,
|
||||||
"role": "assistant",
|
"id": msgID,
|
||||||
"model": model,
|
"type": "message",
|
||||||
"content": []gin.H{{"type": "text", "text": text}},
|
"role": "assistant",
|
||||||
"stop_reason": "end_turn",
|
"content": []gin.H{{"type": "text", "text": text}},
|
||||||
|
"stop_reason": stopReason,
|
||||||
|
"stop_sequence": nil,
|
||||||
"usage": gin.H{
|
"usage": gin.H{
|
||||||
"input_tokens": 10,
|
"input_tokens": 10,
|
||||||
|
"cache_creation_input_tokens": 0,
|
||||||
|
"cache_read_input_tokens": 0,
|
||||||
|
"cache_creation": gin.H{
|
||||||
|
"ephemeral_5m_input_tokens": 0,
|
||||||
|
"ephemeral_1h_input_tokens": 0,
|
||||||
|
},
|
||||||
"output_tokens": outputTokens,
|
"output_tokens": outputTokens,
|
||||||
|
"total_tokens": 10 + outputTokens,
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func billingErrorDetails(err error) (status int, code, message string) {
|
func billingErrorDetails(err error) (status int, code, message string) {
|
||||||
|
|||||||
65
backend/internal/handler/gateway_handler_intercept_test.go
Normal file
65
backend/internal/handler/gateway_handler_intercept_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDetectInterceptType_MaxTokensOneHaikuRequiresClaudeCodeClient(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
|
||||||
|
notClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, false)
|
||||||
|
require.Equal(t, InterceptTypeNone, notClaudeCode)
|
||||||
|
|
||||||
|
isClaudeCode := detectInterceptType(body, "claude-haiku-4-5", 1, false, true)
|
||||||
|
require.Equal(t, InterceptTypeMaxTokensOneHaiku, isClaudeCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectInterceptType_SuggestionModeUnaffected(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[{
|
||||||
|
"role":"user",
|
||||||
|
"content":[{"type":"text","text":"[SUGGESTION MODE:foo]"}]
|
||||||
|
}],
|
||||||
|
"system":[]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
got := detectInterceptType(body, "claude-sonnet-4-5", 256, false, false)
|
||||||
|
require.Equal(t, InterceptTypeSuggestionMode, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendMockInterceptResponse_MaxTokensOneHaiku(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
sendMockInterceptResponse(ctx, "claude-haiku-4-5", InterceptTypeMaxTokensOneHaiku)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var response map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response))
|
||||||
|
require.Equal(t, "max_tokens", response["stop_reason"])
|
||||||
|
|
||||||
|
id, ok := response["id"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, strings.HasPrefix(id, "msg_bdrk_"))
|
||||||
|
|
||||||
|
content, ok := response["content"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, content)
|
||||||
|
|
||||||
|
firstBlock, ok := content[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "#", firstBlock["text"])
|
||||||
|
|
||||||
|
usage, ok := response["usage"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, float64(1), usage["output_tokens"])
|
||||||
|
}
|
||||||
@@ -24,4 +24,8 @@ const (
|
|||||||
ThinkingEnabled Key = "ctx_thinking_enabled"
|
ThinkingEnabled Key = "ctx_thinking_enabled"
|
||||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||||
Group Key = "ctx_group"
|
Group Key = "ctx_group"
|
||||||
|
|
||||||
|
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
|
||||||
|
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent)
|
||||||
|
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
|||||||
//
|
//
|
||||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||||
// Step 3: 对于 messages 路径,进行严格验证:
|
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
||||||
|
// Step 4: 对于 messages 路径,进行严格验证:
|
||||||
// - System prompt 相似度检查
|
// - System prompt 相似度检查
|
||||||
// - X-App header 检查
|
// - X-App header 检查
|
||||||
// - anthropic-beta header 检查
|
// - anthropic-beta header 检查
|
||||||
@@ -75,14 +76,20 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3: messages 路径,进行严格验证
|
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
|
||||||
|
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
|
||||||
|
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
|
||||||
|
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
|
||||||
|
}
|
||||||
|
|
||||||
// 3.1 检查 system prompt 相似度
|
// Step 4: messages 路径,进行严格验证
|
||||||
|
|
||||||
|
// 4.1 检查 system prompt 相似度
|
||||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3.2 检查必需的 headers(值不为空即可)
|
// 4.2 检查必需的 headers(值不为空即可)
|
||||||
xApp := r.Header.Get("X-App")
|
xApp := r.Header.Get("X-App")
|
||||||
if xApp == "" {
|
if xApp == "" {
|
||||||
return false
|
return false
|
||||||
@@ -98,7 +105,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3.3 验证 metadata.user_id
|
// 4.3 验证 metadata.user_id
|
||||||
if body == nil {
|
if body == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
58
backend/internal/service/claude_code_validator_test.go
Normal file
58
backend/internal/service/claude_code_validator_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClaudeCodeValidator_ProbeBypass(t *testing.T) {
|
||||||
|
validator := NewClaudeCodeValidator()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||||
|
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||||
|
|
||||||
|
ok := validator.Validate(req, map[string]any{
|
||||||
|
"model": "claude-haiku-4-5",
|
||||||
|
"max_tokens": 1,
|
||||||
|
})
|
||||||
|
require.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) {
|
||||||
|
validator := NewClaudeCodeValidator()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||||
|
req.Header.Set("User-Agent", "curl/8.0.0")
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||||
|
|
||||||
|
ok := validator.Validate(req, map[string]any{
|
||||||
|
"model": "claude-haiku-4-5",
|
||||||
|
"max_tokens": 1,
|
||||||
|
})
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) {
|
||||||
|
validator := NewClaudeCodeValidator()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||||
|
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||||
|
|
||||||
|
ok := validator.Validate(req, map[string]any{
|
||||||
|
"model": "claude-haiku-4-5",
|
||||||
|
"max_tokens": 1,
|
||||||
|
})
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
|
||||||
|
validator := NewClaudeCodeValidator()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil)
|
||||||
|
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||||
|
|
||||||
|
ok := validator.Validate(req, nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
@@ -29,6 +30,7 @@ type ParsedRequest struct {
|
|||||||
Messages []any // messages 数组
|
Messages []any // messages 数组
|
||||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||||
|
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||||
@@ -79,9 +81,55 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// max_tokens
|
||||||
|
if rawMaxTokens, exists := req["max_tokens"]; exists {
|
||||||
|
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
|
||||||
|
parsed.MaxTokens = maxTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。
|
||||||
|
// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。
|
||||||
|
func parseIntegralNumber(raw any) (int, bool) {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case float64:
|
||||||
|
if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if v > float64(math.MaxInt) || v < float64(math.MinInt) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(v), true
|
||||||
|
case int:
|
||||||
|
return v, true
|
||||||
|
case int8:
|
||||||
|
return int(v), true
|
||||||
|
case int16:
|
||||||
|
return int(v), true
|
||||||
|
case int32:
|
||||||
|
return int(v), true
|
||||||
|
case int64:
|
||||||
|
if v > int64(math.MaxInt) || v < int64(math.MinInt) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(v), true
|
||||||
|
case json.Number:
|
||||||
|
i64, err := v.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(i64), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FilterThinkingBlocks removes thinking blocks from request body
|
// FilterThinkingBlocks removes thinking blocks from request body
|
||||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||||
// This prevents 400 errors from invalid thinking block signatures
|
// This prevents 400 errors from invalid thinking block signatures
|
||||||
|
|||||||
@@ -28,6 +28,20 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
|||||||
require.True(t, parsed.ThinkingEnabled)
|
require.True(t, parsed.ThinkingEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, parsed.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, parsed.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-3","system":null}`)
|
body := []byte(`{"model":"claude-3","system":null}`)
|
||||||
parsed, err := ParseGatewayRequest(body)
|
parsed, err := ParseGatewayRequest(body)
|
||||||
|
|||||||
Reference in New Issue
Block a user