From 7122b3b3b609d530199ff3490caa9b101142eef0 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Sun, 4 Jan 2026 21:29:09 +0800 Subject: [PATCH] =?UTF-8?q?fix(backend):=20=E4=BF=AE=E5=A4=8D=20P0/P1=20?= =?UTF-8?q?=E4=B8=A5=E9=87=8D=E5=AE=89=E5=85=A8=E5=92=8C=E7=A8=B3=E5=AE=9A?= =?UTF-8?q?=E6=80=A7=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0 严重问题修复: - 优化重试机制:降至 5 次 + 指数退避 + 10s 上限,防止请求堆积 - 修复 SSE 错误格式:符合 Anthropic API 规范,添加错误类型标准化 P1 重要问题修复: - 防止 DOS 攻击:使用 io.LimitReader 限制请求体 10MB,流式解析 - 修复计费数据丢失:改为同步计费,使用独立 context 防止中断 技术细节: - 新增 retryBackoffDelay() 和 sleepWithContext() 支持 context 取消 - 新增 normalizeAnthropicErrorType() 和 sanitizePublicErrorMessage() - 新增 parseGatewayRequestStream() 实现流式解析 - 新增 recordUsageSync() 确保计费数据持久化 影响: - 极端场景重试时间从 30s 降至 ≤10s - 防止高并发 OOM 攻击 - 消除计费数据丢失风险 - 提升客户端兼容性 --- backend/internal/handler/gateway_handler.go | 183 ++++++++++++++------ backend/internal/service/gateway_service.go | 7 +- 2 files changed, 135 insertions(+), 55 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 5674386b..8247a0c3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "encoding/json" "errors" @@ -20,6 +21,10 @@ import ( "github.com/gin-gonic/gin" ) +const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB + +var errEmptyRequestBody = errors.New("request body is empty") + // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -30,6 +35,23 @@ type GatewayHandler struct { concurrencyHelper *ConcurrencyHelper } +func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) { + // 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。 + // 使用独立 Background context,避免客户端取消请求导致计费中断。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err) + } +} + // NewGatewayHandler creates a new GatewayHandler func NewGatewayHandler( gatewayService *service.GatewayService, @@ -49,6 +71,78 @@ func NewGatewayHandler( } } +func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) { + if r == nil { + return nil, errEmptyRequestBody + } + + var raw bytes.Buffer + limited := io.LimitReader(r, limit+1) + tee := io.TeeReader(limited, &raw) + decoder := json.NewDecoder(tee) + + var req map[string]any + if err := decoder.Decode(&req); err != nil { + if errors.Is(err, io.EOF) { + return nil, errEmptyRequestBody + } + if int64(raw.Len()) > limit { + return nil, &http.MaxBytesError{Limit: limit} + } + return nil, err + } + + // Ensure the body contains exactly one JSON value (allowing trailing whitespace). + var extra any + if err := decoder.Decode(&extra); err != io.EOF { + if int64(raw.Len()) > limit { + return nil, &http.MaxBytesError{Limit: limit} + } + if err == nil { + return nil, fmt.Errorf("request body must contain a single JSON object") + } + return nil, err + } + if int64(raw.Len()) > limit { + return nil, &http.MaxBytesError{Limit: limit} + } + + parsed := &service.ParsedRequest{ + Body: raw.Bytes(), + } + + if rawModel, exists := req["model"]; exists { + model, ok := rawModel.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = model + } + if rawStream, exists := req["stream"]; exists { + stream, ok := rawStream.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = stream + } + if metadata, ok := req["metadata"].(map[string]any); ok { + if userID, ok := metadata["user_id"].(string); ok { + parsed.MetadataUserID = userID + } + } + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + + return parsed, nil +} + // Messages handles Claude API compatible messages endpoint // POST /v1/messages func (h *GatewayHandler) Messages(c *gin.Context) { @@ -65,27 +159,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) return } + if errors.Is(err, errEmptyRequestBody) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } - - if len(body) == 0 { + if len(parsedReq.Body) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return } - - parsedReq, err := service.ParseGatewayRequest(body) - if err != nil { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } reqModel := parsedReq.Model reqStream := parsedReq.Stream @@ -167,7 +263,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() } @@ -225,9 +321,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -254,20 +350,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account) + // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取) + h.recordUsageSync(apiKey, subscription, result, account) return } } @@ -292,7 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() } @@ -350,7 +434,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) + result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body) } else { result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) } @@ -379,20 +463,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account) + // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取) + h.recordUsageSync(apiKey, subscription, result, account) return } } @@ -595,6 +667,9 @@ func normalizeAnthropicErrorType(errType string) string { case "billing_error": // Not an Anthropic-standard error type; map to the closest equivalent. return "permission_error" + case "subscription_error": + // Not an Anthropic-standard error type; map to the closest equivalent. + return "permission_error" case "upstream_error": // Not an Anthropic-standard error type; keep clients compatible. return "api_error" @@ -684,28 +759,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) return } + if errors.Is(err, errEmptyRequestBody) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } - - if len(body) == 0 { + if len(parsedReq.Body) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return } - parsedReq, err := service.ParseGatewayRequest(body) - if err != nil { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } - // 验证 model 必填 if parsedReq.Model == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index cbd4abd7..ae633c65 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1157,6 +1157,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp = retryResp break } + if retryResp != nil && retryResp.Body != nil { + _ = retryResp.Body.Close() + } log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) } else { log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) @@ -1603,10 +1606,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re // OAuth/Setup Token 账号的 403:标记账号异常 if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode) + log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) } else { // API Key 未配置错误码:不标记账号状态 - log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) + log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } }