feat(backend): 提交后端审计修复与配套测试改动
This commit is contained in:
@@ -404,6 +404,14 @@ gateway:
|
|||||||
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
||||||
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
||||||
|
|
||||||
|
**网关防御纵深建议(重点)**
|
||||||
|
|
||||||
|
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
|
||||||
|
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
|
||||||
|
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
|
||||||
|
- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。
|
||||||
|
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
|
||||||
|
|
||||||
**⚠️ 安全警告:HTTP URL 配置**
|
**⚠️ 安全警告:HTTP URL 配置**
|
||||||
|
|
||||||
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
||||||
|
|||||||
@@ -308,6 +308,12 @@ type GatewayConfig struct {
|
|||||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||||
// 请求体最大字节数,用于网关请求体大小限制
|
// 请求体最大字节数,用于网关请求体大小限制
|
||||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||||
|
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||||
|
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||||
|
// 代理探测响应体读取上限(字节)
|
||||||
|
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||||
|
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||||
|
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||||
@@ -1059,6 +1065,9 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
|
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||||
|
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||||
|
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||||
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.MaxBodySize <= 0 {
|
if c.Gateway.MaxBodySize <= 0 {
|
||||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||||
|
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||||
|
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||||
|
}
|
||||||
if c.Gateway.SoraMaxBodySize < 0 {
|
if c.Gateway.SoraMaxBodySize < 0 {
|
||||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gateway.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gateway.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
h.errorResponse(c, status, errType, message)
|
h.errorResponse(c, status, errType, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||||
|
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||||
|
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// errorResponse 返回Claude API格式的错误响应
|
// errorResponse 返回Claude API格式的错误响应
|
||||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.True(t, wrote)
|
||||||
|
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "error", parsed["type"])
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.False(t, wrote)
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
}
|
||||||
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Error response already handled in Forward, just log
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("openai.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
|||||||
h.errorResponse(c, status, errType, message)
|
h.errorResponse(c, status, errType, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||||
|
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||||
|
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// errorResponse returns OpenAI API format error response
|
// errorResponse returns OpenAI API format error response
|
||||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
|
|||||||
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
|||||||
assert.Equal(t, "test error", errorObj["message"])
|
assert.Equal(t, "test error", errorObj["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.True(t, wrote)
|
||||||
|
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.False(t, wrote)
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
|||||||
return normalizeIP(c.ClientIP())
|
return normalizeIP(c.ClientIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||||
|
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||||
|
// 适用于 ACL / 风控等安全敏感场景。
|
||||||
|
func GetTrustedClientIP(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return normalizeIP(c.ClientIP())
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||||
func normalizeIP(ip string) string {
|
func normalizeIP(ip string) string {
|
||||||
ip = strings.TrimSpace(ip)
|
ip = strings.TrimSpace(ip)
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
package ip
|
package ip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
require.NoError(t, r.SetTrustedProxies(nil))
|
||||||
|
|
||||||
|
r.GET("/t", func(c *gin.Context) {
|
||||||
|
c.String(200, GetTrustedClientIP(c))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/t", nil)
|
||||||
|
req.RemoteAddr = "9.9.9.9:12345"
|
||||||
|
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||||
|
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||||
|
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, 200, w.Code)
|
||||||
|
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
insecure := false
|
insecure := false
|
||||||
allowPrivate := false
|
allowPrivate := false
|
||||||
validateResolvedIP := true
|
validateResolvedIP := true
|
||||||
|
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||||
|
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||||
|
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if insecure {
|
if insecure {
|
||||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
insecureSkipVerify: insecure,
|
insecureSkipVerify: insecure,
|
||||||
allowPrivateHosts: allowPrivate,
|
allowPrivateHosts: allowPrivate,
|
||||||
validateResolvedIP: validateResolvedIP,
|
validateResolvedIP: validateResolvedIP,
|
||||||
|
maxResponseBytes: maxResponseBytes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultProxyProbeTimeout = 30 * time.Second
|
defaultProxyProbeTimeout = 30 * time.Second
|
||||||
|
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||||
)
|
)
|
||||||
|
|
||||||
// probeURLs 按优先级排列的探测 URL 列表
|
// probeURLs 按优先级排列的探测 URL 列表
|
||||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
|||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
allowPrivateHosts bool
|
allowPrivateHosts bool
|
||||||
validateResolvedIP bool
|
validateResolvedIP bool
|
||||||
|
maxResponseBytes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
|||||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxResponseBytes := s.maxResponseBytes
|
||||||
|
if maxResponseBytes <= 0 {
|
||||||
|
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||||
}
|
}
|
||||||
|
if int64(len(body)) > maxResponseBytes {
|
||||||
|
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||||
|
}
|
||||||
|
|
||||||
switch parser {
|
switch parser {
|
||||||
case "ip-api":
|
case "ip-api":
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
|||||||
if err := r.SetTrustedProxies(nil); err != nil {
|
if err := r.SetTrustedProxies(nil); err != nil {
|
||||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||||
}
|
}
|
||||||
|
if cfg.Server.Mode == "release" {
|
||||||
|
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
// 检查 IP 限制(白名单/黑名单)
|
// 检查 IP 限制(白名单/黑名单)
|
||||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetTrustedClientIP(c)
|
||||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||||
|
|||||||
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
user := &service.User{
|
||||||
|
ID: 7,
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
}
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 100,
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "test-key",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
User: user,
|
||||||
|
IPWhitelist: []string{"1.2.3.4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKeyRepo := &stubApiKeyRepo{
|
||||||
|
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
if key != apiKey.Key {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
clone := *apiKey
|
||||||
|
return &clone, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
router := gin.New()
|
||||||
|
require.NoError(t, router.SetTrustedProxies(nil))
|
||||||
|
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||||
|
router.GET("/t", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req.RemoteAddr = "9.9.9.9:12345"
|
||||||
|
req.Header.Set("x-api-key", apiKey.Key)
|
||||||
|
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||||
|
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||||
|
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||||
|
}
|
||||||
|
|
||||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||||
|
|||||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
|||||||
// 公开接口
|
// 公开接口
|
||||||
auth := v1.Group("/auth")
|
auth := v1.Group("/auth")
|
||||||
{
|
{
|
||||||
auth.POST("/register", h.Auth.Register)
|
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||||
auth.POST("/login", h.Auth.Login)
|
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
}), h.Auth.Register)
|
||||||
|
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.Login)
|
||||||
|
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.Login2FA)
|
||||||
|
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.SendVerifyCode)
|
||||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
|||||||
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||||
|
|
||||||
|
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startAuthRouteRedis(t, ctx)
|
||||||
|
|
||||||
|
router := newAuthRoutesTestRouter(rdb)
|
||||||
|
const path = "/api/v1/auth/register"
|
||||||
|
|
||||||
|
for i := 1; i <= 6; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.RemoteAddr = "198.51.100.10:23456"
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if i <= 5 {
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||||
|
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||||
|
t.Helper()
|
||||||
|
ensureAuthRouteDockerAvailable(t)
|
||||||
|
|
||||||
|
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = redisContainer.Terminate(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
redisHost, err := redisContainer.Host(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
require.NoError(t, rdb.Ping(ctx).Err())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
return rdb
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if authRouteDockerAvailable() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||||
|
}
|
||||||
|
|
||||||
|
func authRouteDockerAvailable() bool {
|
||||||
|
if os.Getenv("DOCKER_HOST") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
socketCandidates := []string{
|
||||||
|
"/var/run/docker.sock",
|
||||||
|
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||||
|
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||||
|
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||||
|
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, socket := range socketCandidates {
|
||||||
|
if socket == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(socket); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func authRouteUserHomeDir() string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return home
|
||||||
|
}
|
||||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
v1 := router.Group("/api/v1")
|
||||||
|
|
||||||
|
RegisterAuthRoutes(
|
||||||
|
v1,
|
||||||
|
&handler.Handlers{
|
||||||
|
Auth: &handler.AuthHandler{},
|
||||||
|
Setting: &handler.SettingHandler{},
|
||||||
|
},
|
||||||
|
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
}),
|
||||||
|
redisClient,
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 50 * time.Millisecond,
|
||||||
|
ReadTimeout: 50 * time.Millisecond,
|
||||||
|
WriteTimeout: 50 * time.Millisecond,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
router := newAuthRoutesTestRouter(rdb)
|
||||||
|
paths := []string{
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
"/api/v1/auth/login/2fa",
|
||||||
|
"/api/v1/auth/send-verify-code",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range paths {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||||
|
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
resp = retryResp
|
resp = retryResp
|
||||||
respBody, err = io.ReadAll(resp.Body)
|
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||||
// Log response headers for debugging
|
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var parsed map[string]any
|
|
||||||
if isOAuth {
|
if isOAuth {
|
||||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||||
if uwErr == nil {
|
if uwErr == nil {
|
||||||
respBody = unwrappedBody
|
respBody = unwrappedBody
|
||||||
}
|
}
|
||||||
_ = json.Unmarshal(respBody, &parsed)
|
|
||||||
} else {
|
|
||||||
_ = json.Unmarshal(respBody, &parsed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||||
// Log response headers for debugging
|
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
|
||||||
|
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
|
|||||||
@@ -3,10 +3,15 @@ package service
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
GeminiDebugResponseHeaders: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-RateLimit-Limit": []string{"60"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||||
claudeReq := map[string]any{
|
claudeReq := map[string]any{
|
||||||
"model": "claude-haiku-4-5-20251001",
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
|||||||
@@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
) (*OpenAIUsage, error) {
|
) (*OpenAIUsage, error) {
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
|||||||
return normalized, changed, nil
|
return normalized, changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||||
|
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||||
|
if !strings.Contains(model, "codex") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions := gjson.GetBytes(body, "instructions")
|
||||||
|
if !instructions.Exists() {
|
||||||
|
return "instructions_missing"
|
||||||
|
}
|
||||||
|
if instructions.Type != gjson.String {
|
||||||
|
return "instructions_not_string"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(instructions.String()) == "" {
|
||||||
|
return "instructions_empty"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||||
if reasoningEffort == "" {
|
if reasoningEffort == "" {
|
||||||
@@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
|
||||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
|
||||||
if !strings.Contains(model, "codex") {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
instructions := gjson.GetBytes(body, "instructions")
|
|
||||||
if !instructions.Exists() {
|
|
||||||
return "instructions_missing"
|
|
||||||
}
|
|
||||||
if instructions.Type != gjson.String {
|
|
||||||
return "instructions_not_string"
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(instructions.String()) == "" {
|
|
||||||
return "instructions_empty"
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|||||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||||
|
|
||||||
|
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||||
|
|
||||||
|
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||||
|
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||||
|
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
return defaultUpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, errors.New("response body is nil")
|
||||||
|
}
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(len(body)) > maxBytes {
|
||||||
|
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||||
|
t.Run("use default when config missing", func(t *testing.T) {
|
||||||
|
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("use configured value", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||||
|
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||||
|
t.Run("within limit", func(t *testing.T) {
|
||||||
|
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("ok"), body)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exceeds limit", func(t *testing.T) {
|
||||||
|
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||||
|
require.Nil(t, body)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -146,6 +146,15 @@ gateway:
|
|||||||
# Max request body size in bytes (default: 100MB)
|
# Max request body size in bytes (default: 100MB)
|
||||||
# 请求体最大字节数(默认 100MB)
|
# 请求体最大字节数(默认 100MB)
|
||||||
max_body_size: 104857600
|
max_body_size: 104857600
|
||||||
|
# Max bytes to read for non-stream upstream responses (default: 8MB)
|
||||||
|
# 非流式上游响应体读取上限(默认 8MB)
|
||||||
|
upstream_response_read_max_bytes: 8388608
|
||||||
|
# Max bytes to read for proxy probe responses (default: 1MB)
|
||||||
|
# 代理探测响应体读取上限(默认 1MB)
|
||||||
|
proxy_probe_response_read_max_bytes: 1048576
|
||||||
|
# Enable Gemini upstream response header debug logs (default: false)
|
||||||
|
# 是否开启 Gemini 上游响应头调试日志(默认 false)
|
||||||
|
gemini_debug_response_headers: false
|
||||||
# Sora max request body size in bytes (0=use max_body_size)
|
# Sora max request body size in bytes (0=use max_body_size)
|
||||||
# Sora 请求体最大字节数(0=使用 max_body_size)
|
# Sora 请求体最大字节数(0=使用 max_body_size)
|
||||||
sora_max_body_size: 268435456
|
sora_max_body_size: 268435456
|
||||||
|
|||||||
Reference in New Issue
Block a user