fix(backend): 修复 P0/P1 严重安全和稳定性问题
P0 严重问题修复: - 优化重试机制:降至 5 次 + 指数退避 + 10s 上限,防止请求堆积 - 修复 SSE 错误格式:符合 Anthropic API 规范,添加错误类型标准化 P1 重要问题修复: - 防止 DOS 攻击:使用 io.LimitReader 限制请求体 10MB,流式解析 - 修复计费数据丢失:改为同步计费,使用独立 context 防止中断 技术细节: - 新增 retryBackoffDelay() 和 sleepWithContext() 支持 context 取消 - 新增 normalizeAnthropicErrorType() 和 sanitizePublicErrorMessage() - 新增 parseGatewayRequestStream() 实现流式解析 - 新增 recordUsageSync() 确保计费数据持久化 影响: - 极端场景重试时间从 30s 降至 ≤10s - 防止高并发 OOM 攻击 - 消除计费数据丢失风险 - 提升客户端兼容性
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -20,6 +21,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB
|
||||||
|
|
||||||
|
var errEmptyRequestBody = errors.New("request body is empty")
|
||||||
|
|
||||||
// GatewayHandler handles API gateway requests
|
// GatewayHandler handles API gateway requests
|
||||||
type GatewayHandler struct {
|
type GatewayHandler struct {
|
||||||
gatewayService *service.GatewayService
|
gatewayService *service.GatewayService
|
||||||
@@ -30,6 +35,23 @@ type GatewayHandler struct {
|
|||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) {
|
||||||
|
// 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。
|
||||||
|
// 使用独立 Background context,避免客户端取消请求导致计费中断。
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewGatewayHandler creates a new GatewayHandler
|
// NewGatewayHandler creates a new GatewayHandler
|
||||||
func NewGatewayHandler(
|
func NewGatewayHandler(
|
||||||
gatewayService *service.GatewayService,
|
gatewayService *service.GatewayService,
|
||||||
@@ -49,6 +71,78 @@ func NewGatewayHandler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) {
|
||||||
|
if r == nil {
|
||||||
|
return nil, errEmptyRequestBody
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw bytes.Buffer
|
||||||
|
limited := io.LimitReader(r, limit+1)
|
||||||
|
tee := io.TeeReader(limited, &raw)
|
||||||
|
decoder := json.NewDecoder(tee)
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil, errEmptyRequestBody
|
||||||
|
}
|
||||||
|
if int64(raw.Len()) > limit {
|
||||||
|
return nil, &http.MaxBytesError{Limit: limit}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the body contains exactly one JSON value (allowing trailing whitespace).
|
||||||
|
var extra any
|
||||||
|
if err := decoder.Decode(&extra); err != io.EOF {
|
||||||
|
if int64(raw.Len()) > limit {
|
||||||
|
return nil, &http.MaxBytesError{Limit: limit}
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return nil, fmt.Errorf("request body must contain a single JSON object")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(raw.Len()) > limit {
|
||||||
|
return nil, &http.MaxBytesError{Limit: limit}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := &service.ParsedRequest{
|
||||||
|
Body: raw.Bytes(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawModel, exists := req["model"]; exists {
|
||||||
|
model, ok := rawModel.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid model field type")
|
||||||
|
}
|
||||||
|
parsed.Model = model
|
||||||
|
}
|
||||||
|
if rawStream, exists := req["stream"]; exists {
|
||||||
|
stream, ok := rawStream.(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid stream field type")
|
||||||
|
}
|
||||||
|
parsed.Stream = stream
|
||||||
|
}
|
||||||
|
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||||
|
if userID, ok := metadata["user_id"].(string); ok {
|
||||||
|
parsed.MetadataUserID = userID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// system 字段只要存在就视为显式提供(即使为 null),
|
||||||
|
// 以避免客户端传 null 时被默认 system 误注入。
|
||||||
|
if system, ok := req["system"]; ok {
|
||||||
|
parsed.HasSystem = true
|
||||||
|
parsed.System = system
|
||||||
|
}
|
||||||
|
if messages, ok := req["messages"].([]any); ok {
|
||||||
|
parsed.Messages = messages
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Messages handles Claude API compatible messages endpoint
|
// Messages handles Claude API compatible messages endpoint
|
||||||
// POST /v1/messages
|
// POST /v1/messages
|
||||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||||
@@ -65,27 +159,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取请求体
|
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, errEmptyRequestBody) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var syntaxErr *json.SyntaxError
|
||||||
|
var typeErr *json.UnmarshalTypeError
|
||||||
|
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(parsedReq.Body) == 0 {
|
||||||
if len(body) == 0 {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
|
||||||
if err != nil {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqModel := parsedReq.Model
|
reqModel := parsedReq.Model
|
||||||
reqStream := parsedReq.Stream
|
reqStream := parsedReq.Stream
|
||||||
|
|
||||||
@@ -167,7 +263,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
selection.ReleaseFunc()
|
selection.ReleaseFunc()
|
||||||
}
|
}
|
||||||
@@ -225,9 +321,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
|
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body)
|
||||||
}
|
}
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
@@ -254,20 +350,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
h.recordUsageSync(apiKey, subscription, result, account)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
APIKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: usedAccount,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
|
||||||
}(result, account)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -292,7 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
selection.ReleaseFunc()
|
selection.ReleaseFunc()
|
||||||
}
|
}
|
||||||
@@ -350,7 +434,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body)
|
||||||
} else {
|
} else {
|
||||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||||
}
|
}
|
||||||
@@ -379,20 +463,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
h.recordUsageSync(apiKey, subscription, result, account)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
APIKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: usedAccount,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
|
||||||
}(result, account)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -595,6 +667,9 @@ func normalizeAnthropicErrorType(errType string) string {
|
|||||||
case "billing_error":
|
case "billing_error":
|
||||||
// Not an Anthropic-standard error type; map to the closest equivalent.
|
// Not an Anthropic-standard error type; map to the closest equivalent.
|
||||||
return "permission_error"
|
return "permission_error"
|
||||||
|
case "subscription_error":
|
||||||
|
// Not an Anthropic-standard error type; map to the closest equivalent.
|
||||||
|
return "permission_error"
|
||||||
case "upstream_error":
|
case "upstream_error":
|
||||||
// Not an Anthropic-standard error type; keep clients compatible.
|
// Not an Anthropic-standard error type; keep clients compatible.
|
||||||
return "api_error"
|
return "api_error"
|
||||||
@@ -684,28 +759,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取请求体
|
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, errEmptyRequestBody) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var syntaxErr *json.SyntaxError
|
||||||
|
var typeErr *json.UnmarshalTypeError
|
||||||
|
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(parsedReq.Body) == 0 {
|
||||||
if len(body) == 0 {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
|
||||||
if err != nil {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if parsedReq.Model == "" {
|
if parsedReq.Model == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
|||||||
@@ -1157,6 +1157,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
resp = retryResp
|
resp = retryResp
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if retryResp != nil && retryResp.Body != nil {
|
||||||
|
_ = retryResp.Body.Close()
|
||||||
|
}
|
||||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||||
@@ -1603,10 +1606,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
|||||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||||
if account.IsOAuth() && statusCode == 403 {
|
if account.IsOAuth() && statusCode == 403 {
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
|
||||||
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
|
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
|
||||||
} else {
|
} else {
|
||||||
// API Key 未配置错误码:不标记账号状态
|
// API Key 未配置错误码:不标记账号状态
|
||||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user