feat(安全): 强化安全策略与配置校验
- 增加 CORS/CSP/安全响应头与代理信任配置 - 引入 URL 白名单与私网开关,校验上游与价格源 - 改善 API Key 处理与网关错误返回 - 管理端设置隐藏敏感字段并优化前端提示 - 增加计费熔断与相关配置示例 测试: go test ./...
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@@ -12,6 +15,8 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
const (
|
||||
@@ -28,6 +33,10 @@ const (
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
@@ -81,11 +90,56 @@ type PricingConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string `mapstructure:"allowed_origins"`
|
||||
AllowCredentials bool `mapstructure:"allow_credentials"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
|
||||
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
|
||||
CSP CSPConfig `mapstructure:"csp"`
|
||||
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
|
||||
}
|
||||
|
||||
type URLAllowlistConfig struct {
|
||||
UpstreamHosts []string `mapstructure:"upstream_hosts"`
|
||||
PricingHosts []string `mapstructure:"pricing_hosts"`
|
||||
CRSHosts []string `mapstructure:"crs_hosts"`
|
||||
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
|
||||
}
|
||||
|
||||
type ResponseHeaderConfig struct {
|
||||
AdditionalAllowed []string `mapstructure:"additional_allowed"`
|
||||
ForceRemove []string `mapstructure:"force_remove"`
|
||||
}
|
||||
|
||||
type CSPConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Policy string `mapstructure:"policy"`
|
||||
}
|
||||
|
||||
type ProxyProbeConfig struct {
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||
}
|
||||
|
||||
type BillingConfig struct {
|
||||
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
|
||||
}
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
FailureThreshold int `mapstructure:"failure_threshold"`
|
||||
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
|
||||
HalfOpenRequests int `mapstructure:"half_open_requests"`
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
@@ -192,6 +246,10 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
@@ -242,11 +300,39 @@ func Load() (*Config, error) {
|
||||
}
|
||||
|
||||
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
|
||||
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
|
||||
secret, err := generateJWTSecret(64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate jwt secret error: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
}
|
||||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||||
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
cfg.Security.ResponseHeaders.ForceRemove,
|
||||
)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -259,6 +345,39 @@ func setDefaults() {
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.trusted_proxies", []string{})
|
||||
|
||||
// CORS
|
||||
viper.SetDefault("cors.allowed_origins", []string{})
|
||||
viper.SetDefault("cors.allow_credentials", true)
|
||||
|
||||
// Security
|
||||
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
|
||||
"api.openai.com",
|
||||
"api.anthropic.com",
|
||||
"generativelanguage.googleapis.com",
|
||||
"cloudcode-pa.googleapis.com",
|
||||
"*.openai.azure.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
|
||||
"raw.githubusercontent.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
viper.SetDefault("security.csp.enabled", true)
|
||||
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
|
||||
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
|
||||
|
||||
// Billing
|
||||
viper.SetDefault("billing.circuit_breaker.enabled", true)
|
||||
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
|
||||
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
|
||||
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
|
||||
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@@ -284,7 +403,7 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "change-me-in-production")
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// Default
|
||||
@@ -340,11 +459,39 @@ func setDefaults() {
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required")
|
||||
if c.Server.Mode == "release" {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required in release mode")
|
||||
}
|
||||
if len(c.JWT.Secret) < 32 {
|
||||
return fmt.Errorf("jwt.secret must be at least 32 characters")
|
||||
}
|
||||
if isWeakJWTSecret(c.JWT.Secret) {
|
||||
return fmt.Errorf("jwt.secret is too weak")
|
||||
}
|
||||
}
|
||||
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
|
||||
return fmt.Errorf("jwt.secret must be changed in production")
|
||||
if c.JWT.ExpireHour <= 0 {
|
||||
return fmt.Errorf("jwt.expire_hour must be positive")
|
||||
}
|
||||
if c.JWT.ExpireHour > 168 {
|
||||
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
|
||||
}
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
}
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
|
||||
}
|
||||
}
|
||||
if c.Database.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("database.max_open_conns must be positive")
|
||||
@@ -414,6 +561,51 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeStringSlice(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return values
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isWeakJWTSecret(secret string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(secret))
|
||||
if lower == "" {
|
||||
return true
|
||||
}
|
||||
weak := map[string]struct{}{
|
||||
"change-me-in-production": {},
|
||||
"changeme": {},
|
||||
"secret": {},
|
||||
"password": {},
|
||||
"123456": {},
|
||||
"12345678": {},
|
||||
"admin": {},
|
||||
"jwt-secret": {},
|
||||
}
|
||||
_, exists := weak[lower]
|
||||
return exists
|
||||
}
|
||||
|
||||
func generateJWTSecret(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 32
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// GetServerAddress returns the server address (host:port) from config file or environment variable.
|
||||
// This is a lightweight function that can be used before full config validation,
|
||||
// such as during setup wizard startup.
|
||||
|
||||
Reference in New Issue
Block a user