From 048ed061c2628994315b7c6f6885c55f7db464e7 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Mon, 5 Jan 2026 14:41:08 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E5=AE=89=E5=85=A8):=20=E5=85=B3=E9=97=AD?= =?UTF-8?q?=E7=99=BD=E5=90=8D=E5=8D=95=E6=97=B6=E4=BF=9D=E7=95=99=E6=9C=80?= =?UTF-8?q?=E5=B0=8F=E6=A0=A1=E9=AA=8C=E4=B8=8E=E9=BB=98=E8=AE=A4=E7=99=BD?= =?UTF-8?q?=E5=90=8D=E5=8D=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现 allow_insecure_http 并在关闭校验时执行最小格式验证 - 关闭 allowlist 时要求 URL 可解析且 scheme 合规 - 响应头过滤关闭时使用默认白名单策略 - 更新相关文档、示例与测试覆盖 --- README.md | 3 +- README_CN.md | 3 +- backend/internal/config/config.go | 7 ++- backend/internal/config/config_test.go | 3 ++ .../internal/service/account_test_service.go | 2 +- backend/internal/service/crs_sync_service.go | 6 +++ backend/internal/service/gateway_service.go | 6 ++- .../service/gemini_messages_compat_service.go | 6 ++- .../service/openai_gateway_service.go | 6 ++- .../service/openai_gateway_service_test.go | 33 +++++++++++--- backend/internal/service/pricing_service.go | 6 ++- .../util/responseheaders/responseheaders.go | 45 +++++++------------ .../responseheaders/responseheaders_test.go | 14 +++--- .../internal/util/urlvalidator/validator.go | 33 ++++++++++++++ .../util/urlvalidator/validator_test.go | 24 ++++++++++ deploy/config.example.yaml | 4 +- 16 files changed, 151 insertions(+), 50 deletions(-) create mode 100644 backend/internal/util/urlvalidator/validator_test.go diff --git a/README.md b/README.md index 58ba87cd..0a821184 100644 --- a/README.md +++ b/README.md @@ -273,7 +273,8 @@ Additional security-related options are available in `config.yaml`: - `cors.allowed_origins` for CORS allowlist - `security.url_allowlist` for upstream/pricing/CRS host allowlists - `security.url_allowlist.enabled` to disable URL validation (use with caution) -- `security.response_headers.enabled` to disable response header filtering +- `security.url_allowlist.allow_insecure_http` to allow http URLs when validation is disabled +- `security.response_headers.enabled` to enable configurable response header filtering (disabled uses default allowlist) - `security.csp` to control Content-Security-Policy headers - `billing.circuit_breaker` to fail closed on billing errors - `server.trusted_proxies` to enable X-Forwarded-For parsing diff --git a/README_CN.md b/README_CN.md index 6ee258e5..c126fa0f 100644 --- a/README_CN.md +++ b/README_CN.md @@ -273,7 +273,8 @@ default: - `cors.allowed_origins` 配置 CORS 白名单 - `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单 - `security.url_allowlist.enabled` 可关闭 URL 校验(慎用) -- `security.response_headers.enabled` 可关闭响应头过滤 +- `security.url_allowlist.allow_insecure_http` 关闭校验时允许 http URL +- `security.response_headers.enabled` 可启用可配置响应头过滤(关闭时使用默认白名单) - `security.csp` 配置 Content-Security-Policy - `billing.circuit_breaker` 计费异常时 fail-closed - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4fb94e88..0786b62f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -131,6 +131,8 @@ type URLAllowlistConfig struct { PricingHosts []string `mapstructure:"pricing_hosts"` CRSHosts []string `mapstructure:"crs_hosts"` AllowPrivateHosts bool `mapstructure:"allow_private_hosts"` + // 关闭 URL 白名单校验时,是否允许 http URL(默认只允许 https) + AllowInsecureHTTP bool `mapstructure:"allow_insecure_http"` } type ResponseHeaderConfig struct { @@ -384,10 +386,10 @@ func Load() (*Config, error) { } if !cfg.Security.URLAllowlist.Enabled { - log.Println("Warning: security.url_allowlist.enabled=false; URL validation is disabled.") + log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") } if !cfg.Security.ResponseHeaders.Enabled { - log.Println("Warning: security.response_headers.enabled=false; response header filtering is disabled.") + log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") } if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { @@ -435,6 +437,7 @@ func setDefaults() { }) viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", false) + viper.SetDefault("security.url_allowlist.allow_insecure_http", false) viper.SetDefault("security.response_headers.enabled", false) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a364c2a9..1f6ed58e 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -80,6 +80,9 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if cfg.Security.URLAllowlist.Enabled { t.Fatalf("URLAllowlist.Enabled = true, want false") } + if cfg.Security.URLAllowlist.AllowInsecureHTTP { + t.Fatalf("URLAllowlist.AllowInsecureHTTP = true, want false") + } if cfg.Security.ResponseHeaders.Enabled { t.Fatalf("ResponseHeaders.Enabled = true, want false") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 9adadc10..7121a13d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -73,7 +73,7 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) return "", errors.New("config is not available") } if !s.cfg.Security.URLAllowlist.Enabled { - return strings.TrimSpace(raw), nil + return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index c3a08beb..a6ccb967 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -201,6 +201,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput return nil, err } baseURL = normalized + } else { + normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + baseURL = normalized } if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" { return nil, errors.New("username and password are required") diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 65f5cec4..5bc4d296 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2206,7 +2206,11 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { - return strings.TrimSpace(raw), nil + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 92332f54..37909fcc 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -238,7 +238,11 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { - return strings.TrimSpace(raw), nil + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 029d011e..08bd8df5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1048,7 +1048,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { - return strings.TrimSpace(raw), nil + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 2acb9aef..8562d940 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -174,7 +174,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { Body: pr, Header: http.Header{ "Cache-Control": []string{"upstream"}, - "X-Test": []string{"value"}, + "X-Request-Id": []string{"req-123"}, "Content-Type": []string{"application/custom"}, }, } @@ -196,8 +196,8 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { if rec.Header().Get("Content-Type") != "text/event-stream" { t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type")) } - if rec.Header().Get("X-Test") != "value" { - t.Fatalf("expected X-Test passthrough, got %q", rec.Header().Get("X-Test")) + if rec.Header().Get("X-Request-Id") != "req-123" { + t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id")) } } @@ -226,7 +226,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { } } -func TestOpenAIValidateUpstreamBaseURLDisabledSkipsValidation(t *testing.T) { +func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ URLAllowlist: config.URLAllowlistConfig{Enabled: false}, @@ -234,9 +234,32 @@ func TestOpenAIValidateUpstreamBaseURLDisabledSkipsValidation(t *testing.T) { } svc := &OpenAIGatewayService{cfg: cfg} + if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil { + t.Fatalf("expected http to be rejected when allow_insecure_http is false") + } + normalized, err := svc.validateUpstreamBaseURL("https://example.com") + if err != nil { + t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err) + } + if normalized != "https://example.com" { + t.Fatalf("expected raw url passthrough, got %q", normalized) + } +} + +func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) { + cfg := &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com") if err != nil { - t.Fatalf("expected no error when allowlist disabled, got %v", err) + t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err) } if normalized != "http://not-https.example.com" { t.Fatalf("expected raw url passthrough, got %q", normalized) diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 27e04a1e..30b53c83 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -410,7 +410,11 @@ func (s *PricingService) fetchRemoteHash() (string, error) { func (s *PricingService) validatePricingURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { - return strings.TrimSpace(raw), nil + normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + if err != nil { + return "", fmt.Errorf("invalid pricing url: %w", err) + } + return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts, diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go index f030225f..86c3f624 100644 --- a/backend/internal/util/responseheaders/responseheaders.go +++ b/backend/internal/util/responseheaders/responseheaders.go @@ -42,28 +42,31 @@ var hopByHopHeaders = map[string]struct{}{ } func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header { - if !cfg.Enabled { - return passThroughHeaders(src) - } allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed)) for key := range defaultAllowed { allowed[key] = struct{}{} } - for _, key := range cfg.AdditionalAllowed { - normalized := strings.ToLower(strings.TrimSpace(key)) - if normalized == "" { - continue + // 关闭时只使用默认白名单,additional/force_remove 不生效 + if cfg.Enabled { + for _, key := range cfg.AdditionalAllowed { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + continue + } + allowed[normalized] = struct{}{} } - allowed[normalized] = struct{}{} } - forceRemove := make(map[string]struct{}, len(cfg.ForceRemove)) - for _, key := range cfg.ForceRemove { - normalized := strings.ToLower(strings.TrimSpace(key)) - if normalized == "" { - continue + forceRemove := map[string]struct{}{} + if cfg.Enabled { + forceRemove = make(map[string]struct{}, len(cfg.ForceRemove)) + for _, key := range cfg.ForceRemove { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + continue + } + forceRemove[normalized] = struct{}{} } - forceRemove[normalized] = struct{}{} } filtered := make(http.Header, len(src)) @@ -94,17 +97,3 @@ func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseH } } } - -func passThroughHeaders(src http.Header) http.Header { - filtered := make(http.Header, len(src)) - for key, values := range src { - lower := strings.ToLower(key) - if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop { - continue - } - for _, value := range values { - filtered.Add(key, value) - } - } - return filtered -} diff --git a/backend/internal/util/responseheaders/responseheaders_test.go b/backend/internal/util/responseheaders/responseheaders_test.go index 3fb03d12..f7343267 100644 --- a/backend/internal/util/responseheaders/responseheaders_test.go +++ b/backend/internal/util/responseheaders/responseheaders_test.go @@ -7,28 +7,28 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) -func TestFilterHeadersDisabledPassThrough(t *testing.T) { +func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) { src := http.Header{} src.Add("Content-Type", "application/json") + src.Add("X-Request-Id", "req-123") src.Add("X-Test", "ok") - src.Add("X-Remove", "keep") src.Add("Connection", "keep-alive") src.Add("Content-Length", "123") cfg := config.ResponseHeaderConfig{ Enabled: false, - ForceRemove: []string{"x-test"}, + ForceRemove: []string{"x-request-id"}, } filtered := FilterHeaders(src, cfg) if filtered.Get("Content-Type") != "application/json" { t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type")) } - if filtered.Get("X-Test") != "ok" { - t.Fatalf("expected X-Test passthrough, got %q", filtered.Get("X-Test")) + if filtered.Get("X-Request-Id") != "req-123" { + t.Fatalf("expected X-Request-Id allowed, got %q", filtered.Get("X-Request-Id")) } - if filtered.Get("X-Remove") != "keep" { - t.Fatalf("expected X-Remove passthrough, got %q", filtered.Get("X-Remove")) + if filtered.Get("X-Test") != "" { + t.Fatalf("expected X-Test removed, got %q", filtered.Get("X-Test")) } if filtered.Get("Connection") != "" { t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection")) diff --git a/backend/internal/util/urlvalidator/validator.go b/backend/internal/util/urlvalidator/validator.go index b8f8c72f..56a888b9 100644 --- a/backend/internal/util/urlvalidator/validator.go +++ b/backend/internal/util/urlvalidator/validator.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "strconv" "strings" "time" ) @@ -16,6 +17,38 @@ type ValidationOptions struct { AllowPrivate bool } +func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) { + // 最小格式校验:仅保证 URL 可解析且 scheme 合规,不做白名单/私网/SSRF 校验 + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "", errors.New("url is required") + } + + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return "", fmt.Errorf("invalid url: %s", trimmed) + } + + scheme := strings.ToLower(parsed.Scheme) + if scheme != "https" && (!allowInsecureHTTP || scheme != "http") { + return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme) + } + + host := strings.TrimSpace(parsed.Hostname()) + if host == "" { + return "", errors.New("invalid host") + } + + if port := parsed.Port(); port != "" { + num, err := strconv.Atoi(port) + if err != nil || num <= 0 || num > 65535 { + return "", fmt.Errorf("invalid port: %s", port) + } + } + + return trimmed, nil +} + func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) { trimmed := strings.TrimSpace(raw) if trimmed == "" { diff --git a/backend/internal/util/urlvalidator/validator_test.go b/backend/internal/util/urlvalidator/validator_test.go new file mode 100644 index 00000000..b7f9ffed --- /dev/null +++ b/backend/internal/util/urlvalidator/validator_test.go @@ -0,0 +1,24 @@ +package urlvalidator + +import "testing" + +func TestValidateURLFormat(t *testing.T) { + if _, err := ValidateURLFormat("", false); err == nil { + t.Fatalf("expected empty url to fail") + } + if _, err := ValidateURLFormat("://bad", false); err == nil { + t.Fatalf("expected invalid url to fail") + } + if _, err := ValidateURLFormat("http://example.com", false); err == nil { + t.Fatalf("expected http to fail when allow_insecure_http is false") + } + if _, err := ValidateURLFormat("https://example.com", false); err != nil { + t.Fatalf("expected https to pass, got %v", err) + } + if _, err := ValidateURLFormat("http://example.com", true); err != nil { + t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err) + } + if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil { + t.Fatalf("expected invalid port to fail") + } +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 34c21e77..0f4babb5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -56,8 +56,10 @@ security: crs_hosts: [] # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks) allow_private_hosts: false + # Allow http:// URLs when allowlist is disabled (default: false, require https) + allow_insecure_http: false response_headers: - # Enable response header filtering (disable to pass through upstream headers) + # Enable configurable response header filtering (disable to use default allowlist) enabled: false # Extra allowed response headers from upstream additional_allowed: []