fix(backend): 修复 P0/P1 严重安全和稳定性问题

P0 严重问题修复:
- 优化重试机制:降至 5 次 + 指数退避 + 10s 上限,防止请求堆积
- 修复 SSE 错误格式:符合 Anthropic API 规范,添加错误类型标准化

P1 重要问题修复:
- 防止 DoS 攻击:使用 io.LimitReader 限制请求体 10MB,流式解析
- 修复计费数据丢失:改为同步计费,使用独立 context 防止中断

技术细节:
- 新增 retryBackoffDelay() 和 sleepWithContext() 支持 context 取消
- 新增 normalizeAnthropicErrorType() 和 sanitizePublicErrorMessage()
- 新增 parseGatewayRequestStream() 实现流式解析
- 新增 recordUsageSync() 确保计费数据持久化

影响:
- 极端场景重试时间从 30s 降至 ≤10s
- 防止高并发 OOM 攻击
- 消除计费数据丢失风险
- 提升客户端兼容性
This commit is contained in:
IanShaw027
2026-01-04 21:29:09 +08:00
parent d36392b74f
commit 7122b3b3b6
2 changed files with 135 additions and 55 deletions

View File

@@ -1,6 +1,7 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -20,6 +21,10 @@ import (
"github.com/gin-gonic/gin"
)
// maxGatewayRequestBodyBytes is the request body size limit that the Messages
// and CountTokens handlers pass to parseGatewayRequestStream.
const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB

// errEmptyRequestBody is returned by parseGatewayRequestStream when the reader
// is nil or ends before any JSON value is read; handlers match it with errors.Is.
var errEmptyRequestBody = errors.New("request body is empty")
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -30,6 +35,23 @@ type GatewayHandler struct {
concurrencyHelper *ConcurrencyHelper
}
// recordUsageSync records usage for a forwarded request before returning.
//
// Billing is critical data, so the write happens on the caller's goroutine
// rather than in a detached one whose work would be lost if the process
// crashed. A failure is logged and not returned to the caller.
func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) {
	// Derive from Background rather than the request context so that a client
	// cancelling the request cannot interrupt the billing write; the write is
	// still bounded by a 10 second timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	input := &service.RecordUsageInput{
		Result:       result,
		APIKey:       apiKey,
		User:         apiKey.User,
		Account:      usedAccount,
		Subscription: subscription,
	}

	err := h.gatewayService.RecordUsage(ctx, input)
	if err == nil {
		return
	}
	log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err)
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(
gatewayService *service.GatewayService,
@@ -49,6 +71,78 @@ func NewGatewayHandler(
}
}
// parseGatewayRequestStream decodes a single JSON value from r while enforcing
// a byte limit, and extracts the fields the gateway inspects.
//
// Parameters:
//   - r: the request body; a nil reader is treated as an empty body.
//   - limit: the maximum number of body bytes accepted.
//
// Returns a ParsedRequest whose Body holds every byte read from r, along with
// Model, Stream, MetadataUserID, HasSystem/System and Messages when the
// corresponding fields are present with the expected JSON type.
//
// Errors:
//   - *http.MaxBytesError when more than limit bytes were read. This check
//     takes precedence over every other decode outcome, so an oversized body
//     is never misreported as empty or malformed.
//   - errEmptyRequestBody when r is nil or ends before any JSON value.
//   - the json decoder's error (syntax, type, unexpected EOF) or the reader's
//     error otherwise.
//   - a plain error when a second JSON value follows the first, or when
//     "model" / "stream" are present with the wrong type.
func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) {
	if r == nil {
		return nil, errEmptyRequestBody
	}

	// Allow one byte past the limit so that "exactly limit bytes" can be told
	// apart from "more than limit bytes". Everything the decoder consumes is
	// mirrored into raw so the original bytes can be forwarded unchanged.
	var raw bytes.Buffer
	decoder := json.NewDecoder(io.TeeReader(io.LimitReader(r, limit+1), &raw))
	oversized := func() bool { return int64(raw.Len()) > limit }

	var req map[string]any
	if err := decoder.Decode(&req); err != nil {
		switch {
		case oversized():
			return nil, &http.MaxBytesError{Limit: limit}
		case errors.Is(err, io.EOF):
			return nil, errEmptyRequestBody
		default:
			return nil, err
		}
	}

	// Ensure the body contains exactly one JSON value (trailing whitespace is
	// allowed). Decoding again also drains the reader, so raw ends up holding
	// the complete body on the success path.
	var extra any
	err := decoder.Decode(&extra)
	switch {
	case oversized():
		return nil, &http.MaxBytesError{Limit: limit}
	case err == nil:
		return nil, fmt.Errorf("request body must contain a single JSON object")
	case !errors.Is(err, io.EOF):
		return nil, err
	}

	parsed := &service.ParsedRequest{
		Body: raw.Bytes(),
	}

	// model and stream are optional, but a present value of the wrong type is
	// rejected rather than silently ignored.
	if rawModel, exists := req["model"]; exists {
		model, ok := rawModel.(string)
		if !ok {
			return nil, fmt.Errorf("invalid model field type")
		}
		parsed.Model = model
	}
	if rawStream, exists := req["stream"]; exists {
		stream, ok := rawStream.(bool)
		if !ok {
			return nil, fmt.Errorf("invalid stream field type")
		}
		parsed.Stream = stream
	}

	// metadata.user_id is best-effort: it is copied only when both levels have
	// the expected type.
	if metadata, ok := req["metadata"].(map[string]any); ok {
		if userID, ok := metadata["user_id"].(string); ok {
			parsed.MetadataUserID = userID
		}
	}

	// A system field counts as explicitly provided whenever the key is present,
	// even when its value is null, so that a default system prompt is not
	// injected for clients that deliberately send null.
	if system, ok := req["system"]; ok {
		parsed.HasSystem = true
		parsed.System = system
	}
	if messages, ok := req["messages"].([]any); ok {
		parsed.Messages = messages
	}

	return parsed, nil
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
@@ -65,27 +159,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
if errors.Is(err, errEmptyRequestBody) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
if len(parsedReq.Body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
@@ -167,7 +263,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -225,9 +321,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -254,20 +350,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
h.recordUsageSync(apiKey, subscription, result, account)
return
}
}
@@ -292,7 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -350,7 +434,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
@@ -379,20 +463,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
h.recordUsageSync(apiKey, subscription, result, account)
return
}
}
@@ -595,6 +667,9 @@ func normalizeAnthropicErrorType(errType string) string {
case "billing_error":
// Not an Anthropic-standard error type; map to the closest equivalent.
return "permission_error"
case "subscription_error":
// Not an Anthropic-standard error type; map to the closest equivalent.
return "permission_error"
case "upstream_error":
// Not an Anthropic-standard error type; keep clients compatible.
return "api_error"
@@ -684,28 +759,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
if errors.Is(err, errEmptyRequestBody) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
if len(parsedReq.Body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")

View File

@@ -1157,6 +1157,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
resp = retryResp
break
}
if retryResp != nil && retryResp.Body != nil {
_ = retryResp.Body.Close()
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
@@ -1603,10 +1606,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403:标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
}
}