fix(backend): 修复 P0/P1 严重安全和稳定性问题
P0 严重问题修复: - 优化重试机制:降至 5 次 + 指数退避 + 10s 上限,防止请求堆积 - 修复 SSE 错误格式:符合 Anthropic API 规范,添加错误类型标准化 P1 重要问题修复: - 防止 DoS 攻击:使用 io.LimitReader 限制请求体 10MB,流式解析 - 修复计费数据丢失:改为同步计费,使用独立 context 防止中断 技术细节: - 新增 retryBackoffDelay() 和 sleepWithContext() 支持 context 取消 - 新增 normalizeAnthropicErrorType() 和 sanitizePublicErrorMessage() - 新增 parseGatewayRequestStream() 实现流式解析 - 新增 recordUsageSync() 确保计费数据持久化 影响: - 极端场景重试时间从 30s 降至 ≤10s - 防止高并发 OOM 攻击 - 消除计费数据丢失风险 - 提升客户端兼容性
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -20,6 +21,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// maxGatewayRequestBodyBytes is the request-body size cap, in bytes, that the
// gateway handlers pass to parseGatewayRequestStream.
const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB
|
||||
|
||||
// errEmptyRequestBody is returned by parseGatewayRequestStream when the reader
// is nil or ends before any JSON value is found.
var errEmptyRequestBody = errors.New("request body is empty")
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
@@ -30,6 +35,23 @@ type GatewayHandler struct {
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) {
|
||||
// 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。
|
||||
// 使用独立 Background context,避免客户端取消请求导致计费中断。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
func NewGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
@@ -49,6 +71,78 @@ func NewGatewayHandler(
|
||||
}
|
||||
}
|
||||
|
||||
func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) {
|
||||
if r == nil {
|
||||
return nil, errEmptyRequestBody
|
||||
}
|
||||
|
||||
var raw bytes.Buffer
|
||||
limited := io.LimitReader(r, limit+1)
|
||||
tee := io.TeeReader(limited, &raw)
|
||||
decoder := json.NewDecoder(tee)
|
||||
|
||||
var req map[string]any
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, errEmptyRequestBody
|
||||
}
|
||||
if int64(raw.Len()) > limit {
|
||||
return nil, &http.MaxBytesError{Limit: limit}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the body contains exactly one JSON value (allowing trailing whitespace).
|
||||
var extra any
|
||||
if err := decoder.Decode(&extra); err != io.EOF {
|
||||
if int64(raw.Len()) > limit {
|
||||
return nil, &http.MaxBytesError{Limit: limit}
|
||||
}
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("request body must contain a single JSON object")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if int64(raw.Len()) > limit {
|
||||
return nil, &http.MaxBytesError{Limit: limit}
|
||||
}
|
||||
|
||||
parsed := &service.ParsedRequest{
|
||||
Body: raw.Bytes(),
|
||||
}
|
||||
|
||||
if rawModel, exists := req["model"]; exists {
|
||||
model, ok := rawModel.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = model
|
||||
}
|
||||
if rawStream, exists := req["stream"]; exists {
|
||||
stream, ok := rawStream.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = stream
|
||||
}
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// Messages handles Claude API compatible messages endpoint
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
@@ -65,27 +159,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errEmptyRequestBody) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
if len(parsedReq.Body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
@@ -167,7 +263,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
@@ -225,9 +321,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -254,20 +350,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
|
||||
h.recordUsageSync(apiKey, subscription, result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -292,7 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
@@ -350,7 +434,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||
}
|
||||
@@ -379,20 +463,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
|
||||
h.recordUsageSync(apiKey, subscription, result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -595,6 +667,9 @@ func normalizeAnthropicErrorType(errType string) string {
|
||||
case "billing_error":
|
||||
// Not an Anthropic-standard error type; map to the closest equivalent.
|
||||
return "permission_error"
|
||||
case "subscription_error":
|
||||
// Not an Anthropic-standard error type; map to the closest equivalent.
|
||||
return "permission_error"
|
||||
case "upstream_error":
|
||||
// Not an Anthropic-standard error type; keep clients compatible.
|
||||
return "api_error"
|
||||
@@ -684,28 +759,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errEmptyRequestBody) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
if len(parsedReq.Body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
|
||||
@@ -1157,6 +1157,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
if retryResp != nil && retryResp.Body != nil {
|
||||
_ = retryResp.Body.Close()
|
||||
}
|
||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
@@ -1603,10 +1606,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||
if account.IsOAuth() && statusCode == 403 {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
|
||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
|
||||
} else {
|
||||
// API Key 未配置错误码:不标记账号状态
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user