feat(安全): 添加安全开关并完善测试流程

实现安全开关默认关闭与响应头透传逻辑
- URL 校验与响应头过滤支持开关并覆盖流式路径
- 非流式 Content-Type 透传/默认值按配置生效
- 接入 go test、golangci-lint 与前端 lint/typecheck
- 补充相关测试与配置/文档说明
This commit is contained in:
yangjianbo
2026-01-05 13:54:43 +08:00
parent c8e5455df0
commit 794a9f969b
24 changed files with 1811 additions and 14 deletions

View File

@@ -1,8 +1,12 @@
.PHONY: build test-unit test-integration test-e2e
.PHONY: build test test-unit test-integration test-e2e
build:
go build -o bin/server ./cmd/server
test:
go test ./...
golangci-lint run ./...
test-unit:
go test -tags=unit ./...

View File

@@ -126,6 +126,7 @@ type SecurityConfig struct {
}
type URLAllowlistConfig struct {
Enabled bool `mapstructure:"enabled"`
UpstreamHosts []string `mapstructure:"upstream_hosts"`
PricingHosts []string `mapstructure:"pricing_hosts"`
CRSHosts []string `mapstructure:"crs_hosts"`
@@ -133,6 +134,7 @@ type URLAllowlistConfig struct {
}
type ResponseHeaderConfig struct {
Enabled bool `mapstructure:"enabled"`
AdditionalAllowed []string `mapstructure:"additional_allowed"`
ForceRemove []string `mapstructure:"force_remove"`
}
@@ -381,6 +383,13 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("validate config error: %w", err)
}
if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; URL validation is disabled.")
}
if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; response header filtering is disabled.")
}
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
}
@@ -410,6 +419,7 @@ func setDefaults() {
viper.SetDefault("cors.allow_credentials", true)
// Security
viper.SetDefault("security.url_allowlist.enabled", false)
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
"api.openai.com",
"api.anthropic.com",
@@ -425,6 +435,7 @@ func setDefaults() {
})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
viper.SetDefault("security.response_headers.enabled", false)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)

View File

@@ -68,3 +68,19 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}
func TestLoadDefaultSecurityToggles(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Security.URLAllowlist.Enabled {
t.Fatalf("URLAllowlist.Enabled = true, want false")
}
if cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false")
}
}

View File

@@ -379,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {

View File

@@ -154,6 +154,9 @@ func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
if !s.cfg.Security.URLAllowlist.Enabled {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}

View File

@@ -19,12 +19,14 @@ type pricingRemoteClient struct {
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
})
if err != nil {

View File

@@ -17,9 +17,11 @@ import (
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
@@ -28,6 +30,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
}
}
@@ -37,6 +40,7 @@ type proxyProbeService struct {
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -45,7 +49,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Timeout: 15 * time.Second,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {

View File

@@ -72,6 +72,9 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
if s.cfg == nil {
return "", errors.New("config is not available")
}
if !s.cfg.Security.URLAllowlist.Enabled {
return strings.TrimSpace(raw), nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,

View File

@@ -194,9 +194,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
@@ -204,7 +208,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
ValidateResolvedIP: true,
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {

View File

@@ -1583,6 +1583,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -1837,8 +1841,15 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil
}
@@ -2194,6 +2205,9 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
return strings.TrimSpace(raw), nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,

View File

@@ -237,6 +237,9 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
return strings.TrimSpace(raw), nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
@@ -1720,6 +1723,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
}
log.Printf("[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")

View File

@@ -762,6 +762,10 @@ type openaiStreamingResult struct {
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -1030,12 +1034,22 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
c.Data(resp.StatusCode, "application/json", body)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
return strings.TrimSpace(raw), nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,

View File

@@ -2,6 +2,7 @@ package service
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
@@ -88,3 +89,175 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
}
}
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{},
}
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
if err != nil {
t.Fatalf("handleNonStreamingResponse error: %v", err)
}
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
}
}
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
},
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{
"Cache-Control": []string{"upstream"},
"X-Test": []string{"value"},
"Content-Type": []string{"application/custom"},
},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("handleStreamingResponse error: %v", err)
}
if rec.Header().Get("Cache-Control") != "no-cache" {
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
}
if rec.Header().Get("Content-Type") != "text/event-stream" {
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
}
if rec.Header().Get("X-Test") != "value" {
t.Fatalf("expected X-Test passthrough, got %q", rec.Header().Get("X-Test"))
}
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}
}
func TestOpenAIValidateUpstreamBaseURLDisabledSkipsValidation(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
if err != nil {
t.Fatalf("expected no error when allowlist disabled, got %v", err)
}
if normalized != "http://not-https.example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
Enabled: true,
UpstreamHosts: []string{"example.com"},
},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
t.Fatalf("expected allowlisted host to pass, got %v", err)
}
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
t.Fatalf("expected non-allowlisted host to fail")
}
}

View File

@@ -409,6 +409,9 @@ func (s *PricingService) fetchRemoteHash() (string, error) {
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
return strings.TrimSpace(raw), nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,

View File

@@ -42,6 +42,9 @@ var hopByHopHeaders = map[string]struct{}{
}
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
if !cfg.Enabled {
return passThroughHeaders(src)
}
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
@@ -91,3 +94,17 @@ func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseH
}
}
}
func passThroughHeaders(src http.Header) http.Header {
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
continue
}
for _, value := range values {
filtered.Add(key, value)
}
}
return filtered
}

View File

@@ -0,0 +1,67 @@
package responseheaders
import (
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestFilterHeadersDisabledPassThrough(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Test", "ok")
src.Add("X-Remove", "keep")
src.Add("Connection", "keep-alive")
src.Add("Content-Length", "123")
cfg := config.ResponseHeaderConfig{
Enabled: false,
ForceRemove: []string{"x-test"},
}
filtered := FilterHeaders(src, cfg)
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Test") != "ok" {
t.Fatalf("expected X-Test passthrough, got %q", filtered.Get("X-Test"))
}
if filtered.Get("X-Remove") != "keep" {
t.Fatalf("expected X-Remove passthrough, got %q", filtered.Get("X-Remove"))
}
if filtered.Get("Connection") != "" {
t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection"))
}
if filtered.Get("Content-Length") != "" {
t.Fatalf("expected Content-Length to be removed, got %q", filtered.Get("Content-Length"))
}
}
func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Extra", "ok")
src.Add("X-Remove", "nope")
src.Add("X-Blocked", "nope")
cfg := config.ResponseHeaderConfig{
Enabled: true,
AdditionalAllowed: []string{"x-extra"},
ForceRemove: []string{"x-remove"},
}
filtered := FilterHeaders(src, cfg)
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Extra") != "ok" {
t.Fatalf("expected X-Extra allowed, got %q", filtered.Get("X-Extra"))
}
if filtered.Get("X-Remove") != "" {
t.Fatalf("expected X-Remove removed, got %q", filtered.Get("X-Remove"))
}
if filtered.Get("X-Blocked") != "" {
t.Fatalf("expected X-Blocked removed, got %q", filtered.Get("X-Blocked"))
}
}