diff --git a/README.md b/README.md
index c3a37e68..95c67986 100644
--- a/README.md
+++ b/README.md
@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
+Additional security-related options are available in `config.yaml`:
+
+- `cors.allowed_origins` for CORS allowlist
+- `security.url_allowlist` for upstream/pricing/CRS host allowlists
+- `security.csp` to control Content-Security-Policy headers
+- `billing.circuit_breaker` to fail closed on billing errors
+- `server.trusted_proxies` to enable X-Forwarded-For parsing
+- `turnstile.required` to require Turnstile in release mode
+
```bash
# 6. Run the application
./sub2api
diff --git a/README_CN.md b/README_CN.md
index 8fe46c51..59436998 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
+`config.yaml` 还支持以下安全相关配置:
+
+- `cors.allowed_origins` 配置 CORS 白名单
+- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
+- `security.csp` 配置 Content-Security-Policy
+- `billing.circuit_breaker` 计费异常时 fail-closed
+- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
+- `turnstile.required` 在 release 模式强制启用 Turnstile
+
```bash
# 6. 运行应用
./sub2api
diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go
index d9e0c485..c9dc57bb 100644
--- a/backend/cmd/server/main.go
+++ b/backend/cmd/server/main.go
@@ -86,7 +86,8 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(middleware.Recovery())
- r.Use(middleware.CORS())
+ r.Use(middleware.CORS(config.CORSConfig{}))
+ r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
// Register setup routes
setup.RegisterRoutes(r)
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 47cf5f3a..9f23c993 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
- proxyExitInfoProber := repository.NewProxyExitInfoProber()
+ proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
- accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
+ accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
- concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
- crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
+ concurrencyService := service.NewConcurrencyService(concurrencyCache)
+ crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
- pricingRemoteClient := repository.NewPricingRemoteClient()
+ pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
- geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
- gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
+ geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
+ gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
- openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
+ openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 916efd87..1d8c64ef 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -2,7 +2,10 @@
package config
import (
+ "crypto/rand"
+ "encoding/hex"
"fmt"
+ "log"
"strings"
"time"
@@ -14,6 +17,8 @@ const (
RunModeSimple = "simple"
)
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
@@ -30,6 +35,10 @@ const (
type Config struct {
Server ServerConfig `mapstructure:"server"`
+ CORS CORSConfig `mapstructure:"cors"`
+ Security SecurityConfig `mapstructure:"security"`
+ Billing BillingConfig `mapstructure:"billing"`
+ Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
@@ -37,6 +46,7 @@ type Config struct {
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
+ Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
@@ -95,11 +105,61 @@ type PricingConfig struct {
}
type ServerConfig struct {
- Host string `mapstructure:"host"`
- Port int `mapstructure:"port"`
- Mode string `mapstructure:"mode"` // debug/release
- ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
- IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
+ Host string `mapstructure:"host"`
+ Port int `mapstructure:"port"`
+ Mode string `mapstructure:"mode"` // debug/release
+ ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
+ IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
+ TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
+}
+
+type CORSConfig struct {
+ AllowedOrigins []string `mapstructure:"allowed_origins"`
+ AllowCredentials bool `mapstructure:"allow_credentials"`
+}
+
+type SecurityConfig struct {
+ URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
+ ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
+ CSP CSPConfig `mapstructure:"csp"`
+ ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
+}
+
+type URLAllowlistConfig struct {
+ UpstreamHosts []string `mapstructure:"upstream_hosts"`
+ PricingHosts []string `mapstructure:"pricing_hosts"`
+ CRSHosts []string `mapstructure:"crs_hosts"`
+ AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
+}
+
+type ResponseHeaderConfig struct {
+ AdditionalAllowed []string `mapstructure:"additional_allowed"`
+ ForceRemove []string `mapstructure:"force_remove"`
+}
+
+type CSPConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ Policy string `mapstructure:"policy"`
+}
+
+type ProxyProbeConfig struct {
+ InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
+}
+
+type BillingConfig struct {
+ CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
+}
+
+type CircuitBreakerConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ FailureThreshold int `mapstructure:"failure_threshold"`
+ ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
+ HalfOpenRequests int `mapstructure:"half_open_requests"`
+}
+
+type ConcurrencyConfig struct {
+ // PingInterval: 并发等待期间的 SSE ping 间隔(秒)
+ PingInterval int `mapstructure:"ping_interval"`
}
// GatewayConfig API网关相关配置
@@ -134,6 +194,13 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
+ // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
+ StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
+ // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
+ StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
+ // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
+ MaxLineSize int `mapstructure:"max_line_size"`
+
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
@@ -237,6 +304,10 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
+type TurnstileConfig struct {
+ Required bool `mapstructure:"required"`
+}
+
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
@@ -287,11 +358,39 @@ func Load() (*Config, error) {
}
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
+ cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
+ if cfg.Server.Mode == "" {
+ cfg.Server.Mode = "debug"
+ }
+ cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
+ cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
+ cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
+ cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
+ cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
+
+ if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
+ secret, err := generateJWTSecret(64)
+ if err != nil {
+ return nil, fmt.Errorf("generate jwt secret error: %w", err)
+ }
+ cfg.JWT.Secret = secret
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ }
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
+ if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
+ log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
+ }
+ if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
+ log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
+ cfg.Security.ResponseHeaders.AdditionalAllowed,
+ cfg.Security.ResponseHeaders.ForceRemove,
+ )
+ }
+
return &cfg, nil
}
@@ -304,6 +403,39 @@ func setDefaults() {
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
+ viper.SetDefault("server.trusted_proxies", []string{})
+
+ // CORS
+ viper.SetDefault("cors.allowed_origins", []string{})
+ viper.SetDefault("cors.allow_credentials", true)
+
+ // Security
+ viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
+ "api.openai.com",
+ "api.anthropic.com",
+ "generativelanguage.googleapis.com",
+ "cloudcode-pa.googleapis.com",
+ "*.openai.azure.com",
+ })
+ viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
+ "raw.githubusercontent.com",
+ })
+ viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
+ viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
+ viper.SetDefault("security.response_headers.additional_allowed", []string{})
+ viper.SetDefault("security.response_headers.force_remove", []string{})
+ viper.SetDefault("security.csp.enabled", true)
+ viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
+ viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
+
+ // Billing
+ viper.SetDefault("billing.circuit_breaker.enabled", true)
+ viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
+ viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
+ viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
+
+ // Turnstile
+ viper.SetDefault("turnstile.required", false)
// Database
viper.SetDefault("database.host", "localhost")
@@ -329,7 +461,7 @@ func setDefaults() {
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
- viper.SetDefault("jwt.secret", "change-me-in-production")
+ viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// Default
@@ -357,7 +489,7 @@ func setDefaults() {
viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway
- viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
+ viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
@@ -365,19 +497,23 @@ func setDefaults() {
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
- viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
- viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
- viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
- viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
+ viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
+ viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
+ viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
+ viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
- viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
+ viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
+ viper.SetDefault("gateway.stream_data_interval_timeout", 180)
+ viper.SetDefault("gateway.stream_keepalive_interval", 10)
+ viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
+ viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -396,11 +532,39 @@ func setDefaults() {
}
func (c *Config) Validate() error {
- if c.JWT.Secret == "" {
- return fmt.Errorf("jwt.secret is required")
+ if c.Server.Mode == "release" {
+ if c.JWT.Secret == "" {
+ return fmt.Errorf("jwt.secret is required in release mode")
+ }
+ if len(c.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt.secret must be at least 32 characters")
+ }
+ if isWeakJWTSecret(c.JWT.Secret) {
+ return fmt.Errorf("jwt.secret is too weak")
+ }
}
- if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
- return fmt.Errorf("jwt.secret must be changed in production")
+ if c.JWT.ExpireHour <= 0 {
+ return fmt.Errorf("jwt.expire_hour must be positive")
+ }
+ if c.JWT.ExpireHour > 168 {
+ return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
+ }
+ if c.JWT.ExpireHour > 24 {
+ log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
+ }
+ if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
+ return fmt.Errorf("security.csp.policy is required when CSP is enabled")
+ }
+ if c.Billing.CircuitBreaker.Enabled {
+ if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
+ }
+ if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
+ }
+ if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
+ }
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
@@ -458,6 +622,9 @@ func (c *Config) Validate() error {
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
+ if c.Gateway.IdleConnTimeoutSeconds > 180 {
+ log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
+ }
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
@@ -467,6 +634,26 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
+ if c.Gateway.StreamDataIntervalTimeout < 0 {
+ return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative")
+ }
+ if c.Gateway.StreamDataIntervalTimeout != 0 &&
+ (c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) {
+ return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds")
+ }
+ if c.Gateway.StreamKeepaliveInterval < 0 {
+ return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative")
+ }
+ if c.Gateway.StreamKeepaliveInterval != 0 &&
+ (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
+ return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
+ }
+ if c.Gateway.MaxLineSize < 0 {
+ return fmt.Errorf("gateway.max_line_size must be non-negative")
+ }
+ if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
+ return fmt.Errorf("gateway.max_line_size must be at least 1MB")
+ }
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
@@ -482,9 +669,57 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
+ if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
+ return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
+ }
return nil
}
+func normalizeStringSlice(values []string) []string {
+ if len(values) == 0 {
+ return values
+ }
+ normalized := make([]string, 0, len(values))
+ for _, v := range values {
+ trimmed := strings.TrimSpace(v)
+ if trimmed == "" {
+ continue
+ }
+ normalized = append(normalized, trimmed)
+ }
+ return normalized
+}
+
+func isWeakJWTSecret(secret string) bool {
+ lower := strings.ToLower(strings.TrimSpace(secret))
+ if lower == "" {
+ return true
+ }
+ weak := map[string]struct{}{
+ "change-me-in-production": {},
+ "changeme": {},
+ "secret": {},
+ "password": {},
+ "123456": {},
+ "12345678": {},
+ "admin": {},
+ "jwt-secret": {},
+ }
+ _, exists := weak[lower]
+ return exists
+}
+
+func generateJWTSecret(byteLength int) (string, error) {
+ if byteLength <= 0 {
+ byteLength = 32
+ }
+ buf := make([]byte, byteLength)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index a52b06b4..9ce29785 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1,8 +1,12 @@
package admin
import (
+ "log"
+ "time"
+
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -34,33 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- SMTPHost: settings.SMTPHost,
- SMTPPort: settings.SMTPPort,
- SMTPUsername: settings.SMTPUsername,
- SMTPPassword: settings.SMTPPassword,
- SMTPFrom: settings.SMTPFrom,
- SMTPFromName: settings.SMTPFromName,
- SMTPUseTLS: settings.SMTPUseTLS,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- TurnstileSecretKey: settings.TurnstileSecretKey,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- DefaultConcurrency: settings.DefaultConcurrency,
- DefaultBalance: settings.DefaultBalance,
- EnableModelFallback: settings.EnableModelFallback,
- FallbackModelAnthropic: settings.FallbackModelAnthropic,
- FallbackModelOpenAI: settings.FallbackModelOpenAI,
- FallbackModelGemini: settings.FallbackModelGemini,
- FallbackModelAntigravity: settings.FallbackModelAntigravity,
- EnableIdentityPatch: settings.EnableIdentityPatch,
- IdentityPatchPrompt: settings.IdentityPatchPrompt,
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ SMTPHost: settings.SMTPHost,
+ SMTPPort: settings.SMTPPort,
+ SMTPUsername: settings.SMTPUsername,
+ SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
+ SMTPFrom: settings.SMTPFrom,
+ SMTPFromName: settings.SMTPFromName,
+ SMTPUseTLS: settings.SMTPUseTLS,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ DefaultConcurrency: settings.DefaultConcurrency,
+ DefaultBalance: settings.DefaultBalance,
+ EnableModelFallback: settings.EnableModelFallback,
+ FallbackModelAnthropic: settings.FallbackModelAnthropic,
+ FallbackModelOpenAI: settings.FallbackModelOpenAI,
+ FallbackModelGemini: settings.FallbackModelGemini,
+ FallbackModelAntigravity: settings.FallbackModelAntigravity,
+ EnableIdentityPatch: settings.EnableIdentityPatch,
+ IdentityPatchPrompt: settings.IdentityPatchPrompt,
})
}
@@ -117,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
+ previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
@@ -193,6 +203,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
+ h.auditSettingsUpdate(c, previousSettings, settings, req)
+
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
@@ -201,36 +213,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
- RegistrationEnabled: updatedSettings.RegistrationEnabled,
- EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
- SMTPHost: updatedSettings.SMTPHost,
- SMTPPort: updatedSettings.SMTPPort,
- SMTPUsername: updatedSettings.SMTPUsername,
- SMTPPassword: updatedSettings.SMTPPassword,
- SMTPFrom: updatedSettings.SMTPFrom,
- SMTPFromName: updatedSettings.SMTPFromName,
- SMTPUseTLS: updatedSettings.SMTPUseTLS,
- TurnstileEnabled: updatedSettings.TurnstileEnabled,
- TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
- TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
- SiteName: updatedSettings.SiteName,
- SiteLogo: updatedSettings.SiteLogo,
- SiteSubtitle: updatedSettings.SiteSubtitle,
- APIBaseURL: updatedSettings.APIBaseURL,
- ContactInfo: updatedSettings.ContactInfo,
- DocURL: updatedSettings.DocURL,
- DefaultConcurrency: updatedSettings.DefaultConcurrency,
- DefaultBalance: updatedSettings.DefaultBalance,
- EnableModelFallback: updatedSettings.EnableModelFallback,
- FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
- FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
- FallbackModelGemini: updatedSettings.FallbackModelGemini,
- FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
- EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
- IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
+ RegistrationEnabled: updatedSettings.RegistrationEnabled,
+ EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+ SMTPHost: updatedSettings.SMTPHost,
+ SMTPPort: updatedSettings.SMTPPort,
+ SMTPUsername: updatedSettings.SMTPUsername,
+ SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
+ SMTPFrom: updatedSettings.SMTPFrom,
+ SMTPFromName: updatedSettings.SMTPFromName,
+ SMTPUseTLS: updatedSettings.SMTPUseTLS,
+ TurnstileEnabled: updatedSettings.TurnstileEnabled,
+ TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
+ SiteName: updatedSettings.SiteName,
+ SiteLogo: updatedSettings.SiteLogo,
+ SiteSubtitle: updatedSettings.SiteSubtitle,
+ APIBaseURL: updatedSettings.APIBaseURL,
+ ContactInfo: updatedSettings.ContactInfo,
+ DocURL: updatedSettings.DocURL,
+ DefaultConcurrency: updatedSettings.DefaultConcurrency,
+ DefaultBalance: updatedSettings.DefaultBalance,
+ EnableModelFallback: updatedSettings.EnableModelFallback,
+ FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
+ FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
+ FallbackModelGemini: updatedSettings.FallbackModelGemini,
+ FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
+ EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
+ IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
})
}
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+ if before == nil || after == nil {
+ return
+ }
+
+ changed := diffSettings(before, after, req)
+ if len(changed) == 0 {
+ return
+ }
+
+ subject, _ := middleware.GetAuthSubjectFromContext(c)
+ role, _ := middleware.GetUserRoleFromContext(c)
+ log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
+ time.Now().UTC().Format(time.RFC3339),
+ subject.UserID,
+ role,
+ changed,
+ )
+}
+
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+ changed := make([]string, 0, 20)
+ if before.RegistrationEnabled != after.RegistrationEnabled {
+ changed = append(changed, "registration_enabled")
+ }
+ if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
+ changed = append(changed, "email_verify_enabled")
+ }
+ if before.SMTPHost != after.SMTPHost {
+ changed = append(changed, "smtp_host")
+ }
+ if before.SMTPPort != after.SMTPPort {
+ changed = append(changed, "smtp_port")
+ }
+ if before.SMTPUsername != after.SMTPUsername {
+ changed = append(changed, "smtp_username")
+ }
+ if req.SMTPPassword != "" {
+ changed = append(changed, "smtp_password")
+ }
+ if before.SMTPFrom != after.SMTPFrom {
+ changed = append(changed, "smtp_from_email")
+ }
+ if before.SMTPFromName != after.SMTPFromName {
+ changed = append(changed, "smtp_from_name")
+ }
+ if before.SMTPUseTLS != after.SMTPUseTLS {
+ changed = append(changed, "smtp_use_tls")
+ }
+ if before.TurnstileEnabled != after.TurnstileEnabled {
+ changed = append(changed, "turnstile_enabled")
+ }
+ if before.TurnstileSiteKey != after.TurnstileSiteKey {
+ changed = append(changed, "turnstile_site_key")
+ }
+ if req.TurnstileSecretKey != "" {
+ changed = append(changed, "turnstile_secret_key")
+ }
+ if before.SiteName != after.SiteName {
+ changed = append(changed, "site_name")
+ }
+ if before.SiteLogo != after.SiteLogo {
+ changed = append(changed, "site_logo")
+ }
+ if before.SiteSubtitle != after.SiteSubtitle {
+ changed = append(changed, "site_subtitle")
+ }
+ if before.APIBaseURL != after.APIBaseURL {
+ changed = append(changed, "api_base_url")
+ }
+ if before.ContactInfo != after.ContactInfo {
+ changed = append(changed, "contact_info")
+ }
+ if before.DocURL != after.DocURL {
+ changed = append(changed, "doc_url")
+ }
+ if before.DefaultConcurrency != after.DefaultConcurrency {
+ changed = append(changed, "default_concurrency")
+ }
+ if before.DefaultBalance != after.DefaultBalance {
+ changed = append(changed, "default_balance")
+ }
+ if before.EnableModelFallback != after.EnableModelFallback {
+ changed = append(changed, "enable_model_fallback")
+ }
+ if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
+ changed = append(changed, "fallback_model_anthropic")
+ }
+ if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
+ changed = append(changed, "fallback_model_openai")
+ }
+ if before.FallbackModelGemini != after.FallbackModelGemini {
+ changed = append(changed, "fallback_model_gemini")
+ }
+ if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
+ changed = append(changed, "fallback_model_antigravity")
+ }
+ return changed
+}
+
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 668fb2dc..4c50cedf 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
- SMTPHost string `json:"smtp_host"`
- SMTPPort int `json:"smtp_port"`
- SMTPUsername string `json:"smtp_username"`
- SMTPPassword string `json:"smtp_password,omitempty"`
- SMTPFrom string `json:"smtp_from_email"`
- SMTPFromName string `json:"smtp_from_name"`
- SMTPUseTLS bool `json:"smtp_use_tls"`
+ SMTPHost string `json:"smtp_host"`
+ SMTPPort int `json:"smtp_port"`
+ SMTPUsername string `json:"smtp_username"`
+ SMTPPasswordConfigured bool `json:"smtp_password_configured"`
+ SMTPFrom string `json:"smtp_from_email"`
+ SMTPFromName string `json:"smtp_from_name"`
+ SMTPUseTLS bool `json:"smtp_use_tls"`
- TurnstileEnabled bool `json:"turnstile_enabled"`
- TurnstileSiteKey string `json:"turnstile_site_key"`
- TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
+ TurnstileEnabled bool `json:"turnstile_enabled"`
+ TurnstileSiteKey string `json:"turnstile_site_key"`
+ TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 8247a0c3..de3cbad9 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -1,7 +1,6 @@
package handler
import (
- "bytes"
"context"
"encoding/json"
"errors"
@@ -12,8 +11,10 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -21,10 +22,6 @@ import (
"github.com/gin-gonic/gin"
)
-const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB
-
-var errEmptyRequestBody = errors.New("request body is empty")
-
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -35,23 +32,6 @@ type GatewayHandler struct {
concurrencyHelper *ConcurrencyHelper
}
-func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) {
- // 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。
- // 使用独立 Background context,避免客户端取消请求导致计费中断。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
-
- if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: usedAccount,
- Subscription: subscription,
- }); err != nil {
- log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err)
- }
-}
-
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(
gatewayService *service.GatewayService,
@@ -60,89 +40,22 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
+ cfg *config.Config,
) *GatewayHandler {
+ pingInterval := time.Duration(0)
+ if cfg != nil {
+ pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+ }
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
- concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
+ concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
}
}
-func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) {
- if r == nil {
- return nil, errEmptyRequestBody
- }
-
- var raw bytes.Buffer
- limited := io.LimitReader(r, limit+1)
- tee := io.TeeReader(limited, &raw)
- decoder := json.NewDecoder(tee)
-
- var req map[string]any
- if err := decoder.Decode(&req); err != nil {
- if errors.Is(err, io.EOF) {
- return nil, errEmptyRequestBody
- }
- if int64(raw.Len()) > limit {
- return nil, &http.MaxBytesError{Limit: limit}
- }
- return nil, err
- }
-
- // Ensure the body contains exactly one JSON value (allowing trailing whitespace).
- var extra any
- if err := decoder.Decode(&extra); err != io.EOF {
- if int64(raw.Len()) > limit {
- return nil, &http.MaxBytesError{Limit: limit}
- }
- if err == nil {
- return nil, fmt.Errorf("request body must contain a single JSON object")
- }
- return nil, err
- }
- if int64(raw.Len()) > limit {
- return nil, &http.MaxBytesError{Limit: limit}
- }
-
- parsed := &service.ParsedRequest{
- Body: raw.Bytes(),
- }
-
- if rawModel, exists := req["model"]; exists {
- model, ok := rawModel.(string)
- if !ok {
- return nil, fmt.Errorf("invalid model field type")
- }
- parsed.Model = model
- }
- if rawStream, exists := req["stream"]; exists {
- stream, ok := rawStream.(bool)
- if !ok {
- return nil, fmt.Errorf("invalid stream field type")
- }
- parsed.Stream = stream
- }
- if metadata, ok := req["metadata"].(map[string]any); ok {
- if userID, ok := metadata["user_id"].(string); ok {
- parsed.MetadataUserID = userID
- }
- }
- // system 字段只要存在就视为显式提供(即使为 null),
- // 以避免客户端传 null 时被默认 system 误注入。
- if system, ok := req["system"]; ok {
- parsed.HasSystem = true
- parsed.System = system
- }
- if messages, ok := req["messages"].([]any); ok {
- parsed.Messages = messages
- }
-
- return parsed, nil
-}
-
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
@@ -159,29 +72,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
+ // 读取请求体
+ body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
- if errors.Is(err, errEmptyRequestBody) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
- return
- }
- var syntaxErr *json.SyntaxError
- var typeErr *json.UnmarshalTypeError
- if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
- return
- }
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
- if len(parsedReq.Body) == 0 {
+
+ if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
+
+ parsedReq, err := service.ParseGatewayRequest(body)
+ if err != nil {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+ return
+ }
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
@@ -217,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
+ // 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
+ userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -224,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
- h.handleStreamingAwareError(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required", streamStarted)
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -252,9 +166,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
- log.Printf("Select account failed: %v", err)
if len(failedAccountIDs) == 0 {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted)
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
@@ -263,7 +176,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
- if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
+ if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -317,13 +230,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
+ // 账号槽位/等待计数需要在超时或断开时安全回收
+ accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+ accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
- result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body)
+ result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else {
- result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body)
+ result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -350,8 +266,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
- h.recordUsageSync(apiKey, subscription, result, account)
+ // 异步记录使用量(subscription已在函数开头获取)
+ go func(result *service.ForwardResult, usedAccount *service.Account) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: usedAccount,
+ Subscription: subscription,
+ }); err != nil {
+ log.Printf("Record usage failed: %v", err)
+ }
+ }(result, account)
return
}
}
@@ -365,9 +293,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
- log.Printf("Select account failed: %v", err)
if len(failedAccountIDs) == 0 {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted)
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
@@ -376,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
- if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
+ if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -430,11 +357,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
+ // 账号槽位/等待计数需要在超时或断开时安全回收
+ accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+ accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
- result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body)
+ result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
@@ -463,8 +393,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取)
- h.recordUsageSync(apiKey, subscription, result, account)
+ // 异步记录使用量(subscription已在函数开头获取)
+ go func(result *service.ForwardResult, usedAccount *service.Account) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: usedAccount,
+ Subscription: subscription,
+ }); err != nil {
+ log.Printf("Record usage failed: %v", err)
+ }
+ }(result, account)
return
}
}
@@ -640,71 +582,32 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int,
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
- return http.StatusBadGateway, "api_error", "Upstream authentication failed, please contact administrator"
+ return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
- return http.StatusBadGateway, "api_error", "Upstream access forbidden, please contact administrator"
+ return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
- return http.StatusBadGateway, "api_error", "Upstream service temporarily unavailable"
+ return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
- return http.StatusBadGateway, "api_error", "Upstream request failed"
+ return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
-func normalizeAnthropicErrorType(errType string) string {
- switch errType {
- case "invalid_request_error",
- "authentication_error",
- "permission_error",
- "not_found_error",
- "rate_limit_error",
- "api_error",
- "overloaded_error":
- return errType
- case "billing_error":
- // Not an Anthropic-standard error type; map to the closest equivalent.
- return "permission_error"
- case "subscription_error":
- // Not an Anthropic-standard error type; map to the closest equivalent.
- return "permission_error"
- case "upstream_error":
- // Not an Anthropic-standard error type; keep clients compatible.
- return "api_error"
- default:
- return "api_error"
- }
-}
-
-const maxPublicErrorMessageLen = 512
-
-func sanitizePublicErrorMessage(message string) string {
- cleaned := strings.TrimSpace(message)
- cleaned = strings.ReplaceAll(cleaned, "\r", " ")
- cleaned = strings.ReplaceAll(cleaned, "\n", " ")
- if len(cleaned) > maxPublicErrorMessageLen {
- cleaned = cleaned[:maxPublicErrorMessageLen] + "..."
- }
- return cleaned
-}
-
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
- normalizedType := normalizeAnthropicErrorType(errType)
- publicMessage := sanitizePublicErrorMessage(message)
-
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
- // Anthropic streaming spec: send `event: error` with JSON `data`.
+ // Send error event in SSE format with proper JSON marshaling
errorData := map[string]any{
"type": "error",
"error": map[string]string{
- "type": normalizedType,
- "message": publicMessage,
+ "type": errType,
+ "message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
@@ -712,11 +615,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
_ = c.Error(err)
return
}
- if _, err := fmt.Fprintf(c.Writer, "event: error\n"); err != nil {
- _ = c.Error(err)
- return
- }
- if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", string(jsonBytes)); err != nil {
+ errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
+ if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
@@ -725,19 +625,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
}
// Normal case: return JSON response with proper status code
- h.errorResponse(c, status, normalizedType, publicMessage)
+ h.errorResponse(c, status, errType, message)
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
- normalizedType := normalizeAnthropicErrorType(errType)
- publicMessage := sanitizePublicErrorMessage(message)
-
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
- "type": normalizedType,
- "message": publicMessage,
+ "type": errType,
+ "message": message,
},
})
}
@@ -759,30 +656,28 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
- parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
+ // 读取请求体
+ body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
- if errors.Is(err, errEmptyRequestBody) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
- return
- }
- var syntaxErr *json.SyntaxError
- var typeErr *json.UnmarshalTypeError
- if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
- return
- }
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
- if len(parsedReq.Body) == 0 {
+
+ if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
+ parsedReq, err := service.ParseGatewayRequest(body)
+ if err != nil {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+ return
+ }
+
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
@@ -795,8 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- log.Printf("Billing eligibility check failed: %v", err)
- h.errorResponse(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required")
+ status, code, message := billingErrorDetails(err)
+ h.errorResponse(c, status, code, message)
return
}
@@ -806,8 +701,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
- log.Printf("Select account failed: %v", err)
- h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model")
+ h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@@ -923,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
+
+func billingErrorDetails(err error) (status int, code, message string) {
+ if errors.Is(err, service.ErrBillingServiceUnavailable) {
+ msg := pkgerrors.Message(err)
+ if msg == "" {
+ msg = "Billing service temporarily unavailable. Please retry later."
+ }
+ return http.StatusServiceUnavailable, "billing_service_error", msg
+ }
+ msg := pkgerrors.Message(err)
+ if msg == "" {
+ msg = err.Error()
+ }
+ return http.StatusForbidden, "billing_error", msg
+}
diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go
index 9d2e4a9d..2eb3ac72 100644
--- a/backend/internal/handler/gateway_helper.go
+++ b/backend/internal/handler/gateway_helper.go
@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
+ "sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,8 +27,8 @@ import (
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
- // pingInterval 流式响应等待时发送 ping 的间隔
- pingInterval = 15 * time.Second
+ // defaultPingInterval 流式响应等待时发送 ping 的默认间隔
+ defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
+ // SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
+ SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
@@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
+ pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
-func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
+func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
+ if pingInterval <= 0 {
+ pingInterval = defaultPingInterval
+ }
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
+ pingInterval: pingInterval,
}
}
+// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
+// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
+func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
+ if releaseFunc == nil {
+ return nil
+ }
+ var once sync.Once
+ wrapped := func() {
+ once.Do(releaseFunc)
+ }
+ go func() {
+ <-ctx.Done()
+ wrapped()
+ }()
+ return wrapped
+}
+
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
- pingTicker := time.NewTicker(pingInterval)
+ pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 2dbc7660..aa75e6c1 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
- geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
+ geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
+ // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
+ userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- googleError(c, http.StatusForbidden, err.Error())
+ status, _, message := billingErrorDetails(err)
+ googleError(c, status, message)
return
}
@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
+ // 账号槽位/等待计数需要在超时或断开时安全回收
+ accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+ accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult
@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
- if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
+ if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
continue
}
for _, v := range vv {
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 518cd10a..04d268a5 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -10,6 +10,7 @@ import (
"net/http"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
+ cfg *config.Config,
) *OpenAIGatewayHandler {
+ pingInterval := time.Duration(0)
+ if cfg != nil {
+ pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+ }
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
- concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
+ concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
}
}
@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
+ // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
+ userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
- h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
+ // 账号槽位/等待计数需要在超时或断开时安全回收
+ accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+ accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go
index 8a81c09a..7bf5cff4 100644
--- a/backend/internal/pkg/httpclient/pool.go
+++ b/backend/internal/pkg/httpclient/pool.go
@@ -25,13 +25,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
- defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
+ defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
+ ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
+ ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
+ AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100)
@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err
}
+ var rt http.RoundTripper = transport
+ if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
+ rt = &validatedTransport{base: transport}
+ }
return &http.Client{
- Transport: transport,
+ Transport: rt,
Timeout: opts.Timeout,
}, nil
}
@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
- return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
+ return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
+ opts.ProxyStrict,
+ opts.ValidateResolvedIP,
+ opts.AllowPrivateHosts,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
+
+type validatedTransport struct {
+ base http.RoundTripper
+}
+
+func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ if req != nil && req.URL != nil {
+ host := strings.TrimSpace(req.URL.Hostname())
+ if host != "" {
+ if err := urlvalidator.ValidateResolvedIP(host); err != nil {
+ return nil, err
+ }
+ }
+ }
+ return t.base.RoundTrip(req)
+}
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index 35e7f535..677fce52 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
- reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
+ reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
- log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
+ log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
- reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
+ reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
-
-func prefix(s string, n int) string {
- if n <= 0 {
- return ""
- }
- if len(s) <= n {
- return s
- }
- return s[:n]
-}
diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go
index 424d1a9a..4c87b2de 100644
--- a/backend/internal/repository/claude_usage_service.go
+++ b/backend/internal/repository/claude_usage_service.go
@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
- usageURL string
+ usageURL string
+ allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
- ProxyURL: proxyURL,
- Timeout: 30 * time.Second,
+ ProxyURL: proxyURL,
+ Timeout: 30 * time.Second,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go
index c3570076..2e10f3e5 100644
--- a/backend/internal/repository/claude_usage_service_test.go
+++ b/backend/internal/repository/claude_usage_service_test.go
@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`)
}))
- s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
+ s.fetcher = &claudeUsageService{
+ usageURL: s.srv.URL,
+ allowPrivateHosts: true,
+ }
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope")
}))
- s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
+ s.fetcher = &claudeUsageService{
+ usageURL: s.srv.URL,
+ allowPrivateHosts: true,
+ }
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json")
}))
- s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
+ s.fetcher = &claudeUsageService{
+ usageURL: s.srv.URL,
+ allowPrivateHosts: true,
+ }
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done()
}))
- s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
+ s.fetcher = &claudeUsageService{
+ usageURL: s.srv.URL,
+ allowPrivateHosts: true,
+ }
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go
index 3fa4b1ff..dd53c091 100644
--- a/backend/internal/repository/github_release_service.go
+++ b/backend/internal/repository/github_release_service.go
@@ -14,18 +14,23 @@ import (
)
type githubReleaseClient struct {
- httpClient *http.Client
+ httpClient *http.Client
+ allowPrivateHosts bool
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
+ allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{
- Timeout: 30 * time.Second,
+ Timeout: 30 * time.Second,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
- httpClient: sharedClient,
+ httpClient: sharedClient,
+ allowPrivateHosts: allowPrivate,
}
}
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
- Timeout: 10 * time.Minute,
+ Timeout: 10 * time.Minute,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
diff --git a/backend/internal/repository/github_release_service_test.go b/backend/internal/repository/github_release_service_test.go
index ea849d46..4eebe81d 100644
--- a/backend/internal/repository/github_release_service_test.go
+++ b/backend/internal/repository/github_release_service_test.go
@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq)
}
+func newTestGitHubReleaseClient() *githubReleaseClient {
+ return &githubReleaseClient{
+ httpClient: &http.Client{},
+ allowPrivateHosts: true,
+ }
+}
+
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound)
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum"))
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError)
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done()
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content"))
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
+ allowPrivateHosts: true,
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
+ allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
+ allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
+ allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done()
}))
- client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
- require.True(s.T(), ok, "type assertion failed")
- s.client = client
+ s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go
index f0669979..21723d4a 100644
--- a/backend/internal/repository/http_upstream.go
+++ b/backend/internal/repository/http_upstream.go
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
- // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟)
- // 超时后连接会被关闭,释放系统资源
- defaultIdleConnTimeout = 300 * time.Second
+ // defaultIdleConnTimeout: 默认空闲连接超时时间(90秒)
+ // 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
+ defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
+ if err := s.validateRequestHost(req); err != nil {
+ return nil, err
+ }
+
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
@@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
+func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
+ if s.cfg == nil {
+ return false
+ }
+ return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
+}
+
+func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
+ if !s.shouldValidateResolvedIP() {
+ return nil
+ }
+ if req == nil || req.URL == nil {
+ return errors.New("request url is nil")
+ }
+ host := strings.TrimSpace(req.URL.Hostname())
+ if host == "" {
+ return errors.New("request host is empty")
+ }
+ if err := urlvalidator.ValidateResolvedIP(host); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
+ }
+ return s.validateRequestHost(req)
+}
+
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
@@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err)
}
client := &http.Client{Transport: transport}
+ if s.shouldValidateResolvedIP() {
+ client.CheckRedirect = s.redirectChecker
+ }
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go
index 241b490f..fbe44c5e 100644
--- a/backend/internal/repository/http_upstream_test.go
+++ b/backend/internal/repository/http_upstream_test.go
@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
- s.cfg = &config.Config{}
+ s.cfg = &config.Config{
+ Security: config.SecurityConfig{
+ URLAllowlist: config.URLAllowlistConfig{
+ AllowPrivateHosts: true,
+ },
+ },
+ }
}
// newService 创建测试用的 httpUpstreamService 实例
diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go
index 11f82fd3..0a6d0cd9 100644
--- a/backend/internal/repository/pricing_service.go
+++ b/backend/internal/repository/pricing_service.go
@@ -8,6 +8,7 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
-func NewPricingRemoteClient() service.PricingRemoteClient {
+func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
+ allowPrivate := false
+ if cfg != nil {
+ allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
+ }
sharedClient, err := httpclient.GetClient(httpclient.Options{
- Timeout: 30 * time.Second,
+ Timeout: 30 * time.Second,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go
index 112c7eaa..6745ac58 100644
--- a/backend/internal/repository/pricing_service_test.go
+++ b/backend/internal/repository/pricing_service_test.go
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
- client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
+ client, ok := NewPricingRemoteClient(&config.Config{
+ Security: config.SecurityConfig{
+ URLAllowlist: config.URLAllowlistConfig{
+ AllowPrivateHosts: true,
+ },
+ },
+ }).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index f5f625f9..b49b4efb 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -5,28 +5,48 @@ import (
"encoding/json"
"fmt"
"io"
+ "log"
"net/http"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
-func NewProxyExitInfoProber() service.ProxyExitInfoProber {
- return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
+func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
+ insecure := false
+ allowPrivate := false
+ if cfg != nil {
+ insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
+ allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
+ }
+ if insecure {
+ log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
+ }
+ return &proxyProbeService{
+ ipInfoURL: defaultIPInfoURL,
+ insecureSkipVerify: insecure,
+ allowPrivateHosts: allowPrivate,
+ }
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
- ipInfoURL string
+ ipInfoURL string
+ insecureSkipVerify bool
+ allowPrivateHosts bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
- InsecureSkipVerify: true,
+ InsecureSkipVerify: s.insecureSkipVerify,
+ ProxyStrict: true,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go
index e7270324..fe45adbb 100644
--- a/backend/internal/repository/proxy_probe_service_test.go
+++ b/backend/internal/repository/proxy_probe_service_test.go
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
- s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
+ s.prober = &proxyProbeService{
+ ipInfoURL: "http://ipinfo.test/json",
+ allowPrivateHosts: true,
+ }
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go
index cf6083e2..89748cd3 100644
--- a/backend/internal/repository/turnstile_service.go
+++ b/backend/internal/repository/turnstile_service.go
@@ -22,7 +22,8 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
- Timeout: 10 * time.Second,
+ Timeout: 10 * time.Second,
+ ValidateResolvedIP: true,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 6b8af91e..8a469661 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
- "smtp_password": "secret",
+ "smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
- "turnstile_secret_key": "secret-key",
+ "turnstile_secret_key_configured": true,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index 0239410a..a8740ecc 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -2,6 +2,7 @@
package server
import (
+ "log"
"net/http"
"time"
@@ -36,6 +37,15 @@ func ProvideRouter(
r := gin.New()
r.Use(middleware2.Recovery())
+ if len(cfg.Server.TrustedProxies) > 0 {
+ if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
+ log.Printf("Failed to set trusted proxies: %v", err)
+ }
+ } else {
+ if err := r.SetTrustedProxies(nil); err != nil {
+ log.Printf("Failed to disable trusted proxies: %v", err)
+ }
+ }
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}
diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go
index eb8c2aff..74ff8af3 100644
--- a/backend/internal/server/middleware/api_key_auth.go
+++ b/backend/internal/server/middleware/api_key_auth.go
@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
+ queryKey := strings.TrimSpace(c.Query("key"))
+ queryApiKey := strings.TrimSpace(c.Query("api_key"))
+ if queryKey != "" || queryApiKey != "" {
+ AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
+ return
+ }
+
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key")
}
- // 如果header中没有,尝试从query参数中提取(Google API key风格)
- if apiKeyString == "" {
- apiKeyString = c.Query("key")
- }
-
- // 兼容常见别名
- if apiKeyString == "" {
- apiKeyString = c.Query("api_key")
- }
-
// 如果所有header都没有API key
if apiKeyString == "" {
- AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
+ AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return
}
diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go
index f2796db4..c5afd7ef 100644
--- a/backend/internal/server/middleware/api_key_auth_google.go
+++ b/backend/internal/server/middleware/api_key_auth_google.go
@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
+ if v := strings.TrimSpace(c.Query("api_key")); v != "" {
+ abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
+ return
+ }
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
- if v := strings.TrimSpace(c.Query("key")); v != "" {
- return v
- }
- if v := strings.TrimSpace(c.Query("api_key")); v != "" {
- return v
+ if allowGoogleQueryKey(c.Request.URL.Path) {
+ if v := strings.TrimSpace(c.Query("key")); v != "" {
+ return v
+ }
}
return ""
}
+func allowGoogleQueryKey(path string) bool {
+ return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
+}
+
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go
index 96dcdbb3..0ed5a4a2 100644
--- a/backend/internal/server/middleware/api_key_auth_google_test.go
+++ b/backend/internal/server/middleware/api_key_auth_google_test.go
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
+func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ r := gin.New()
+ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
+ return nil, errors.New("should not be called")
+ },
+ })
+ r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
+ r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ var resp googleErrorResponse
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Error.Code)
+ require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
+ require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
+}
+
+func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ r := gin.New()
+ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
+ return &service.APIKey{
+ ID: 1,
+ Key: key,
+ Status: service.StatusActive,
+ User: &service.User{
+ ID: 123,
+ Status: service.StatusActive,
+ },
+ }, nil
+ },
+ })
+ cfg := &config.Config{RunMode: config.RunModeSimple}
+ r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
+ r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go
index bc16279f..7d82f183 100644
--- a/backend/internal/server/middleware/cors.go
+++ b/backend/internal/server/middleware/cors.go
@@ -1,24 +1,103 @@
package middleware
import (
+ "log"
+ "net/http"
+ "strings"
+ "sync"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
+var corsWarningOnce sync.Once
+
// CORS 跨域中间件
-func CORS() gin.HandlerFunc {
+func CORS(cfg config.CORSConfig) gin.HandlerFunc {
+ allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
+ allowAll := false
+ for _, origin := range allowedOrigins {
+ if origin == "*" {
+ allowAll = true
+ break
+ }
+ }
+ wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
+ if wildcardWithSpecific {
+ allowedOrigins = []string{"*"}
+ }
+ allowCredentials := cfg.AllowCredentials
+
+ corsWarningOnce.Do(func() {
+ if len(allowedOrigins) == 0 {
+ log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
+ }
+ if wildcardWithSpecific {
+ log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
+ }
+ if allowAll && allowCredentials {
+ log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
+ }
+ })
+ if allowAll && allowCredentials {
+ allowCredentials = false
+ }
+
+ allowedSet := make(map[string]struct{}, len(allowedOrigins))
+ for _, origin := range allowedOrigins {
+ if origin == "" || origin == "*" {
+ continue
+ }
+ allowedSet[origin] = struct{}{}
+ }
+
return func(c *gin.Context) {
- // 设置允许跨域的响应头
- c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
- c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ origin := strings.TrimSpace(c.GetHeader("Origin"))
+ originAllowed := allowAll
+ if origin != "" && !allowAll {
+ _, originAllowed = allowedSet[origin]
+ }
+
+ if originAllowed {
+ if allowAll {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
+ } else if origin != "" {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
+ c.Writer.Header().Add("Vary", "Origin")
+ }
+ if allowCredentials {
+ c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+ }
+
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
- if c.Request.Method == "OPTIONS" {
- c.AbortWithStatus(204)
+ if c.Request.Method == http.MethodOptions {
+ if originAllowed {
+ c.AbortWithStatus(http.StatusNoContent)
+ } else {
+ c.AbortWithStatus(http.StatusForbidden)
+ }
return
}
c.Next()
}
}
+
+func normalizeOrigins(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ normalized := make([]string, 0, len(values))
+ for _, value := range values {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ continue
+ }
+ normalized = append(normalized, trimmed)
+ }
+ return normalized
+}
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
new file mode 100644
index 00000000..9fca0cd3
--- /dev/null
+++ b/backend/internal/server/middleware/security_headers.go
@@ -0,0 +1,26 @@
+package middleware
+
+import (
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+)
+
+// SecurityHeaders sets baseline security headers for all responses.
+func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
+ policy := strings.TrimSpace(cfg.Policy)
+ if policy == "" {
+ policy = config.DefaultCSPPolicy
+ }
+
+ return func(c *gin.Context) {
+ c.Header("X-Content-Type-Options", "nosniff")
+ c.Header("X-Frame-Options", "DENY")
+ c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
+ if cfg.Enabled {
+ c.Header("Content-Security-Policy", policy)
+ }
+ c.Next()
+ }
+}
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index 6eebb6d8..15a1b325 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
- r.Use(middleware2.CORS())
+ r.Use(middleware2.CORS(cfg.CORS))
+ r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 6573be4b..e49da48f 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"io"
"log"
@@ -14,9 +15,11 @@ import (
"regexp"
"strings"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
+ cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -53,15 +57,32 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
+ cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
+ cfg: cfg,
}
}
+func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
+ if s.cfg == nil {
+ return "", errors.New("config is not available")
+ }
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", err
+ }
+ return normalized, nil
+}
+
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
- apiURL = account.GetBaseURL()
- if apiURL == "" {
- apiURL = "https://api.anthropic.com"
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.anthropic.com"
}
- apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
- apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
- strings.TrimRight(baseURL, "/"), modelID)
+ strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
- fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 7776e4c3..7763bc40 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -11,6 +11,7 @@ import (
"log"
"net/http"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -1103,57 +1104,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported")
}
- reader := bufio.NewReader(resp.Body)
+ // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
+ scanner := bufio.NewScanner(resp.Body)
+ maxLineSize := defaultMaxLineSize
+ if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
+ maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
+ }
+ scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
+ type scanEvent struct {
+ line string
+ err error
+ }
+ // 独立 goroutine 读取上游,避免读取阻塞影响超时处理
+ events := make(chan scanEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev scanEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ for scanner.Scan() {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ if !sendEvent(scanEvent{line: scanner.Text()}) {
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ _ = sendEvent(scanEvent{err: err})
+ }
+ }()
+ defer close(done)
+
+ // 上游数据间隔超时保护(防止上游挂起长期占用连接)
+ streamInterval := time.Duration(0)
+ if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
+ streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
+ }
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ // 仅发送一次错误事件,避免多次写入导致协议混乱
+ errorEventSent := false
+ sendErrorEvent := func(reason string) {
+ if errorEventSent {
+ return
+ }
+ errorEventSent = true
+ _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ flusher.Flush()
+ }
+
for {
- line, err := reader.ReadString('\n')
- if len(line) > 0 {
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if ev.err != nil {
+ if errors.Is(ev.err, bufio.ErrTooLong) {
+ log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
+ sendErrorEvent("response_too_large")
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
+ }
+ sendErrorEvent("stream_read_error")
+ return nil, ev.err
+ }
+
+ line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
- _, _ = io.WriteString(c.Writer, line)
- flusher.Flush()
- } else {
- // 解包 v1internal 响应
- inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
- if parseErr == nil && inner != nil {
- payload = string(inner)
- }
-
- // 解析 usage
- var parsed map[string]any
- if json.Unmarshal(inner, &parsed) == nil {
- if u := extractGeminiUsage(parsed); u != nil {
- usage = u
- }
- }
-
- if firstTokenMs == nil {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
- }
-
- _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
+ if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
flusher.Flush()
+ continue
}
- } else {
- _, _ = io.WriteString(c.Writer, line)
- flusher.Flush()
- }
- }
- if errors.Is(err, io.EOF) {
- break
- }
- if err != nil {
- return nil, err
+ // 解包 v1internal 响应
+ inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
+ if parseErr == nil && inner != nil {
+ payload = string(inner)
+ }
+
+ // 解析 usage
+ var parsed map[string]any
+ if json.Unmarshal(inner, &parsed) == nil {
+ if u := extractGeminiUsage(parsed); u != nil {
+ usage = u
+ }
+ }
+
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+
+ if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
+ sendErrorEvent("write_failed")
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
+ continue
+ }
+
+ if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
+
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
+ continue
+ }
+ log.Printf("Stream data interval timeout (antigravity)")
+ sendErrorEvent("stream_timeout")
+ return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
-
- return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
@@ -1292,7 +1381,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
- reader := bufio.NewReader(resp.Body)
+ // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
+ scanner := bufio.NewScanner(resp.Body)
+ maxLineSize := defaultMaxLineSize
+ if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
+ maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
+ }
+ scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
@@ -1307,13 +1402,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
- for {
- line, err := reader.ReadString('\n')
- if err != nil && !errors.Is(err, io.EOF) {
- return nil, fmt.Errorf("stream read error: %w", err)
+ type scanEvent struct {
+ line string
+ err error
+ }
+ // 独立 goroutine 读取上游,避免读取阻塞影响超时处理
+ events := make(chan scanEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev scanEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
}
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ for scanner.Scan() {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ if !sendEvent(scanEvent{line: scanner.Text()}) {
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ _ = sendEvent(scanEvent{err: err})
+ }
+ }()
+ defer close(done)
- if len(line) > 0 {
+ streamInterval := time.Duration(0)
+ if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
+ streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
+ }
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ // 仅发送一次错误事件,避免多次写入导致协议混乱
+ errorEventSent := false
+ sendErrorEvent := func(reason string) {
+ if errorEventSent {
+ return
+ }
+ errorEventSent = true
+ _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ flusher.Flush()
+ }
+
+ for {
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ // 发送结束事件
+ finalEvents, agUsage := processor.Finish()
+ if len(finalEvents) > 0 {
+ _, _ = c.Writer.Write(finalEvents)
+ flusher.Flush()
+ }
+ return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
+ }
+ if ev.err != nil {
+ if errors.Is(ev.err, bufio.ErrTooLong) {
+ log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
+ sendErrorEvent("response_too_large")
+ return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
+ }
+ sendErrorEvent("stream_read_error")
+ return nil, fmt.Errorf("stream read error: %w", ev.err)
+ }
+
+ line := ev.line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
@@ -1328,23 +1495,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
+ sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
- }
- if errors.Is(err, io.EOF) {
- break
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
+ continue
+ }
+ log.Printf("Stream data interval timeout (antigravity)")
+ sendErrorEvent("stream_timeout")
+ return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
- // 发送结束事件
- finalEvents, agUsage := processor.Finish()
- if len(finalEvents) > 0 {
- _, _ = c.Writer.Write(finalEvents)
- flusher.Flush()
- }
-
- return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 69765520..91551314 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
+ required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
+
+ if required {
+ if s.settingService == nil {
+ log.Println("[Auth] Turnstile required but settings service is not configured")
+ return ErrTurnstileNotConfigured
+ }
+ enabled := s.settingService.IsTurnstileEnabled(ctx)
+ secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
+ if !enabled || !secretConfigured {
+ log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
+ return ErrTurnstileNotConfigured
+ }
+ }
+
if s.turnstileService == nil {
+ if required {
+ log.Println("[Auth] Turnstile required but service not configured")
+ return ErrTurnstileNotConfigured
+ }
return nil // 服务未配置则跳过验证
}
+
+ if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
+ log.Println("[Auth] Turnstile enabled but secret key not configured")
+ }
+
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index 86148b37..c09cafb9 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -16,7 +16,8 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
- ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
+ ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
+ ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
- cache BillingCache
- userRepo UserRepository
- subRepo UserSubscriptionRepository
- cfg *config.Config
+ cache BillingCache
+ userRepo UserRepository
+ subRepo UserSubscriptionRepository
+ cfg *config.Config
+ circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
+ svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
+ if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
+ return ErrBillingServiceUnavailable
+ }
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
- // 缓存/数据库错误,允许通过(降级处理)
- log.Printf("Warning: get user balance failed, allowing request: %v", err)
- return nil
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnFailure(err)
+ }
+ log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
+ return ErrBillingServiceUnavailable.WithCause(err)
+ }
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
- // 缓存/数据库错误,降级使用传入的subscription进行检查
- log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
- return s.checkSubscriptionLimitsFallback(subscription, group)
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnFailure(err)
+ }
+ log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
+ return ErrBillingServiceUnavailable.WithCause(err)
+ }
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
-// checkSubscriptionLimitsFallback 降级检查订阅限额
-func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
- if subscription == nil {
- return ErrSubscriptionInvalid
- }
+type billingCircuitBreakerState int
- if !subscription.IsActive() {
- return ErrSubscriptionInvalid
- }
+const (
+ billingCircuitClosed billingCircuitBreakerState = iota
+ billingCircuitOpen
+ billingCircuitHalfOpen
+)
- if !subscription.CheckDailyLimit(group, 0) {
- return ErrDailyLimitExceeded
- }
-
- if !subscription.CheckWeeklyLimit(group, 0) {
- return ErrWeeklyLimitExceeded
- }
-
- if !subscription.CheckMonthlyLimit(group, 0) {
- return ErrMonthlyLimitExceeded
- }
-
- return nil
+type billingCircuitBreaker struct {
+ mu sync.Mutex
+ state billingCircuitBreakerState
+ failures int
+ openedAt time.Time
+ failureThreshold int
+ resetTimeout time.Duration
+ halfOpenRequests int
+ halfOpenRemaining int
+}
+
+func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
+ if !cfg.Enabled {
+ return nil
+ }
+ resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
+ if resetTimeout <= 0 {
+ resetTimeout = 30 * time.Second
+ }
+ halfOpen := cfg.HalfOpenRequests
+ if halfOpen <= 0 {
+ halfOpen = 1
+ }
+ threshold := cfg.FailureThreshold
+ if threshold <= 0 {
+ threshold = 5
+ }
+ return &billingCircuitBreaker{
+ state: billingCircuitClosed,
+ failureThreshold: threshold,
+ resetTimeout: resetTimeout,
+ halfOpenRequests: halfOpen,
+ }
+}
+
+func (b *billingCircuitBreaker) Allow() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ switch b.state {
+ case billingCircuitClosed:
+ return true
+ case billingCircuitOpen:
+ if time.Since(b.openedAt) < b.resetTimeout {
+ return false
+ }
+ b.state = billingCircuitHalfOpen
+ b.halfOpenRemaining = b.halfOpenRequests
+ log.Printf("ALERT: billing circuit breaker entering half-open state")
+ fallthrough
+ case billingCircuitHalfOpen:
+ if b.halfOpenRemaining <= 0 {
+ return false
+ }
+ b.halfOpenRemaining--
+ return true
+ default:
+ return false
+ }
+}
+
+func (b *billingCircuitBreaker) OnFailure(err error) {
+ if b == nil {
+ return
+ }
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ switch b.state {
+ case billingCircuitOpen:
+ return
+ case billingCircuitHalfOpen:
+ b.state = billingCircuitOpen
+ b.openedAt = time.Now()
+ b.halfOpenRemaining = 0
+ log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
+ return
+ default:
+ b.failures++
+ if b.failures >= b.failureThreshold {
+ b.state = billingCircuitOpen
+ b.openedAt = time.Now()
+ b.halfOpenRemaining = 0
+ log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
+ }
+ }
+}
+
+func (b *billingCircuitBreaker) OnSuccess() {
+ if b == nil {
+ return
+ }
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ previousState := b.state
+ previousFailures := b.failures
+
+ b.state = billingCircuitClosed
+ b.failures = 0
+ b.halfOpenRemaining = 0
+
+ // 只有状态真正发生变化时才记录日志
+ if previousState != billingCircuitClosed {
+ log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
+ } else if previousFailures > 0 {
+ log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
+ }
+}
+
+func circuitStateString(state billingCircuitBreakerState) string {
+ switch state {
+ case billingCircuitClosed:
+ return "closed"
+ case billingCircuitOpen:
+ return "open"
+ case billingCircuitHalfOpen:
+ return "half-open"
+ default:
+ return "unknown"
+ }
}
diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go
index 1bf5a11e..759034e7 100644
--- a/backend/internal/service/crs_sync_service.go
+++ b/backend/internal/service/crs_sync_service.go
@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
- "net/url"
"strconv"
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
+ cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
+ cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
+ cfg: cfg,
}
}
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
- baseURL, err := normalizeBaseURL(input.BaseURL)
+ if s.cfg == nil {
+ return nil, errors.New("config is not available")
+ }
+ baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
@@ -196,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
client, err := httpclient.GetClient(httpclient.Options{
- Timeout: 20 * time.Second,
+ Timeout: 20 * time.Second,
+ ValidateResolvedIP: true,
+ AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
-func normalizeBaseURL(raw string) (string, error) {
- trimmed := strings.TrimSpace(raw)
- if trimmed == "" {
- return "", errors.New("base_url is required")
+func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
+ // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
+ requireAllowlist := len(allowlist) > 0
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: allowlist,
+ RequireAllowlist: requireAllowlist,
+ AllowPrivate: allowPrivate,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
}
- u, err := url.Parse(trimmed)
- if err != nil || u.Scheme == "" || u.Host == "" {
- return "", fmt.Errorf("invalid base_url: %s", trimmed)
- }
- u.Path = strings.TrimRight(u.Path, "/")
- return strings.TrimRight(u.String(), "/"), nil
+ return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index c706fb80..5d39c01d 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
+ defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
@@ -1342,7 +1346,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
- targetURL = baseURL + "/v1/messages"
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/v1/messages"
+ }
}
// OAuth账号:应用统一指纹
@@ -1711,51 +1721,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
- scanner.Buffer(make([]byte, 64*1024), 1024*1024)
+ maxLineSize := defaultMaxLineSize
+ if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
+ maxLineSize = s.cfg.Gateway.MaxLineSize
+ }
+ scanner.Buffer(make([]byte, 64*1024), maxLineSize)
+
+ type scanEvent struct {
+ line string
+ err error
+ }
+ // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
+ events := make(chan scanEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev scanEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ for scanner.Scan() {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ if !sendEvent(scanEvent{line: scanner.Text()}) {
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ _ = sendEvent(scanEvent{err: err})
+ }
+ }()
+ defer close(done)
+
+ streamInterval := time.Duration(0)
+ if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
+ streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
+ }
+ // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
+ errorEventSent := false
+ sendErrorEvent := func(reason string) {
+ if errorEventSent {
+ return
+ }
+ errorEventSent = true
+ _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ flusher.Flush()
+ }
needModelReplace := originalModel != mappedModel
- for scanner.Scan() {
- line := scanner.Text()
- if line == "event: error" {
- return nil, errors.New("have error in stream")
- }
-
- // Extract data from SSE line (supports both "data: " and "data:" formats)
- if sseDataRe.MatchString(line) {
- data := sseDataRe.ReplaceAllString(line, "")
-
- // 如果有模型映射,替换响应中的model字段
- if needModelReplace {
- line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ for {
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if ev.err != nil {
+ if errors.Is(ev.err, bufio.ErrTooLong) {
+ log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
+ sendErrorEvent("response_too_large")
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
+ }
+ sendErrorEvent("stream_read_error")
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
+ }
+ line := ev.line
+ if line == "event: error" {
+ return nil, errors.New("have error in stream")
}
- // 转发行
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
- }
- flusher.Flush()
+ // Extract data from SSE line (supports both "data: " and "data:" formats)
+ if sseDataRe.MatchString(line) {
+ data := sseDataRe.ReplaceAllString(line, "")
- // 记录首字时间:第一个有效的 content_block_delta 或 message_start
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
+ // 如果有模型映射,替换响应中的model字段
+ if needModelReplace {
+ line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ }
+
+ // 转发行
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
+
+ // 记录首字时间:第一个有效的 content_block_delta 或 message_start
+ if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ s.parseSSEUsage(data, usage)
+ } else {
+ // 非 data 行直接转发
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
}
- s.parseSSEUsage(data, usage)
- } else {
- // 非 data 行直接转发
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
+ continue
}
- flusher.Flush()
+ log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
+ sendErrorEvent("stream_timeout")
+ return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
- if err := scanner.Err(); err != nil {
- return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
- }
-
- return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
@@ -1860,12 +1952,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- // 透传响应头
- for key, values := range resp.Header {
- for _, value := range values {
- c.Header(key, value)
- }
- }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
@@ -2137,7 +2224,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
- targetURL = baseURL + "/v1/messages/count_tokens"
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/v1/messages/count_tokens"
+ }
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -2217,6 +2310,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
+func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
+
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 99e5bdf3..38050eab 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -18,9 +18,12 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
+ cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
+ cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
+ cfg: cfg,
}
}
@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
+func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
+
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -382,16 +400,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -428,7 +450,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
- fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
+ baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, "", err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -454,12 +480,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -700,12 +730,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -737,7 +771,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
- fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
+ baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, "", err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -763,12 +801,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1702,6 +1744,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1823,11 +1867,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
- fullURL := strings.TrimRight(baseURL, "/") + path
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1866,9 +1914,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
+ wwwAuthenticate := resp.Header.Get("Www-Authenticate")
+ filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
+ if wwwAuthenticate != "" {
+ filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
+ }
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
- Headers: resp.Header.Clone(),
+ Headers: filteredHeaders,
Body: body,
}, nil
}
diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go
index 1dce9da9..48d31da9 100644
--- a/backend/internal/service/gemini_oauth_service.go
+++ b/backend/internal/service/gemini_oauth_service.go
@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
- ProxyURL: strings.TrimSpace(proxyURL),
- Timeout: 30 * time.Second,
+ ProxyURL: strings.TrimSpace(proxyURL),
+ Timeout: 30 * time.Second,
+ ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 02f58369..b9cf4b9e 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
- if baseURL != "" {
- targetURL = baseURL + "/responses"
- } else {
+ if baseURL == "" {
targetURL = openaiPlatformAPIURL
+ } else {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -775,48 +782,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
- scanner.Buffer(make([]byte, 64*1024), 1024*1024)
+ maxLineSize := defaultMaxLineSize
+ if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
+ maxLineSize = s.cfg.Gateway.MaxLineSize
+ }
+ scanner.Buffer(make([]byte, 64*1024), maxLineSize)
+
+ type scanEvent struct {
+ line string
+ err error
+ }
+ // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
+ events := make(chan scanEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev scanEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ for scanner.Scan() {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ if !sendEvent(scanEvent{line: scanner.Text()}) {
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ _ = sendEvent(scanEvent{err: err})
+ }
+ }()
+ defer close(done)
+
+ streamInterval := time.Duration(0)
+ if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
+ streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
+ }
+ // 仅监控上游数据间隔超时,不被下游写入阻塞影响
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ keepaliveInterval := time.Duration(0)
+ if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
+ keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
+ }
+ // 下游 keepalive 仅用于防止代理空闲断开
+ var keepaliveTicker *time.Ticker
+ if keepaliveInterval > 0 {
+ keepaliveTicker = time.NewTicker(keepaliveInterval)
+ defer keepaliveTicker.Stop()
+ }
+ var keepaliveCh <-chan time.Time
+ if keepaliveTicker != nil {
+ keepaliveCh = keepaliveTicker.C
+ }
+ // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
+ lastDataAt := time.Now()
+
+ // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
+ errorEventSent := false
+ sendErrorEvent := func(reason string) {
+ if errorEventSent {
+ return
+ }
+ errorEventSent = true
+ _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ flusher.Flush()
+ }
needModelReplace := originalModel != mappedModel
- for scanner.Scan() {
- line := scanner.Text()
-
- // Extract data from SSE line (supports both "data: " and "data:" formats)
- if openaiSSEDataRe.MatchString(line) {
- data := openaiSSEDataRe.ReplaceAllString(line, "")
-
- // Replace model in response if needed
- if needModelReplace {
- line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ for {
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if ev.err != nil {
+ if errors.Is(ev.err, bufio.ErrTooLong) {
+ log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
+ sendErrorEvent("response_too_large")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
+ }
+ sendErrorEvent("stream_read_error")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
- // Forward line
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
- }
- flusher.Flush()
+ line := ev.line
+ lastDataAt = time.Now()
- // Record first token time
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
+ // Extract data from SSE line (supports both "data: " and "data:" formats)
+ if openaiSSEDataRe.MatchString(line) {
+ data := openaiSSEDataRe.ReplaceAllString(line, "")
+
+ // Replace model in response if needed
+ if needModelReplace {
+ line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ }
+
+ // Forward line
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
+
+ // Record first token time
+ if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ s.parseSSEUsage(data, usage)
+ } else {
+ // Forward non-data lines as-is
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ sendErrorEvent("write_failed")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ }
+ flusher.Flush()
}
- s.parseSSEUsage(data, usage)
- } else {
- // Forward non-data lines as-is
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
+ continue
+ }
+ log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
+ sendErrorEvent("stream_timeout")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
+
+ case <-keepaliveCh:
+ if time.Since(lastDataAt) < keepaliveInterval {
+ continue
+ }
+ if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
- if err := scanner.Err(); err != nil {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
- }
-
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
@@ -911,18 +1028,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- // Pass through headers
- for key, values := range resp.Header {
- for _, value := range values {
- c.Header(key, value)
- }
- }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
c.Data(resp.StatusCode, "application/json", body)
return usage, nil
}
+func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
+
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
new file mode 100644
index 00000000..bcad7ac8
--- /dev/null
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -0,0 +1,90 @@
+package service
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+)
+
+func TestOpenAIStreamingTimeout(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 1,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ start := time.Now()
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
+ _ = pw.Close()
+ _ = pr.Close()
+
+ if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
+ t.Fatalf("expected stream timeout error, got %v", err)
+ }
+ if !strings.Contains(rec.Body.String(), "stream_timeout") {
+ t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
+ }
+}
+
+func TestOpenAIStreamingTooLong(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: 64 * 1024,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ // 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
+ payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
+ _, _ = pw.Write([]byte(payload))
+ }()
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
+ _ = pr.Close()
+
+ if !errors.Is(err, bufio.ErrTooLong) {
+ t.Fatalf("expected ErrTooLong, got %v", err)
+ }
+ if !strings.Contains(rec.Body.String(), "response_too_large") {
+ t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
+ }
+}
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index bb050d0a..58a24c0d 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
- log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
+ remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
+ if err != nil {
+ return err
+ }
+ log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
- body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
+ var expectedHash string
+ if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
+ expectedHash, err = s.fetchRemoteHash()
+ if err != nil {
+ return fmt.Errorf("fetch remote hash: %w", err)
+ }
+ }
+
+ body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
+ if expectedHash != "" {
+ actualHash := sha256.Sum256(body)
+ if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
+ return fmt.Errorf("pricing hash mismatch")
+ }
+ }
+
// 解析JSON数据(使用灵活的解析方式)
data, err := s.parsePricingData(body)
if err != nil {
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
+ hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
+ if err != nil {
+ return "", err
+ }
+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
+ hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
+ if err != nil {
+ return "", err
+ }
+ return strings.TrimSpace(hash), nil
+}
+
+func (s *PricingService) validatePricingURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid pricing url: %w", err)
+ }
+ return normalized, nil
}
// computeFileHash 计算文件哈希
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index a331594e..6ce8ba2b 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -228,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
- RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
- EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
- SMTPHost: settings[SettingKeySMTPHost],
- SMTPUsername: settings[SettingKeySMTPUsername],
- SMTPFrom: settings[SettingKeySMTPFrom],
- SMTPFromName: settings[SettingKeySMTPFromName],
- SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
- TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
- TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
- SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
- SiteLogo: settings[SettingKeySiteLogo],
- SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
- APIBaseURL: settings[SettingKeyAPIBaseURL],
- ContactInfo: settings[SettingKeyContactInfo],
- DocURL: settings[SettingKeyDocURL],
+ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
+ EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
+ SMTPHost: settings[SettingKeySMTPHost],
+ SMTPUsername: settings[SettingKeySMTPUsername],
+ SMTPFrom: settings[SettingKeySMTPFrom],
+ SMTPFromName: settings[SettingKeySMTPFromName],
+ SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
+ SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
+ TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
+ TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
+ TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
+ SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
+ SiteLogo: settings[SettingKeySiteLogo],
+ SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
+ APIBaseURL: settings[SettingKeyAPIBaseURL],
+ ContactInfo: settings[SettingKeyContactInfo],
+ DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 1fba5e13..de0331f7 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
- SMTPHost string
- SMTPPort int
- SMTPUsername string
- SMTPPassword string
- SMTPFrom string
- SMTPFromName string
- SMTPUseTLS bool
+ SMTPHost string
+ SMTPPort int
+ SMTPUsername string
+ SMTPPassword string
+ SMTPPasswordConfigured bool
+ SMTPFrom string
+ SMTPFromName string
+ SMTPUseTLS bool
- TurnstileEnabled bool
- TurnstileSiteKey string
- TurnstileSecretKey string
+ TurnstileEnabled bool
+ TurnstileSiteKey string
+ TurnstileSecretKey string
+ TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string
diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go
index 435f6289..ad077735 100644
--- a/backend/internal/setup/setup.go
+++ b/backend/internal/setup/setup.go
@@ -9,6 +9,7 @@ import (
"log"
"os"
"strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/repository"
@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
+ if strings.EqualFold(cfg.Server.Mode, "release") {
+ return fmt.Errorf("jwt secret is required in release mode")
+ }
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Test connections
@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
+ if strings.EqualFold(cfg.Server.Mode, "release") {
+ return fmt.Errorf("jwt secret is required in release mode")
+ }
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
- log.Println("Generated JWT secret automatically")
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Generate admin password if not provided
@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err)
}
cfg.Admin.Password = password
- log.Printf("Generated admin password: %s", cfg.Admin.Password)
- log.Println("IMPORTANT: Save this password! It will not be shown again.")
+ fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
+ fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection
diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go
new file mode 100644
index 00000000..b2d2429f
--- /dev/null
+++ b/backend/internal/util/logredact/redact.go
@@ -0,0 +1,100 @@
+package logredact
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+// maxRedactDepth 限制递归深度以防止栈溢出
+const maxRedactDepth = 32
+
+var defaultSensitiveKeys = map[string]struct{}{
+ "authorization_code": {},
+ "code": {},
+ "code_verifier": {},
+ "access_token": {},
+ "refresh_token": {},
+ "id_token": {},
+ "client_secret": {},
+ "password": {},
+}
+
+func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
+ if input == nil {
+ return map[string]any{}
+ }
+ keys := buildKeySet(extraKeys)
+ redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
+ if !ok {
+ return map[string]any{}
+ }
+ return redacted
+}
+
+func RedactJSON(raw []byte, extraKeys ...string) string {
+ if len(raw) == 0 {
+ return ""
+ }
+ var value any
+ if err := json.Unmarshal(raw, &value); err != nil {
+ return "
-
-
-
-
-
-
+
@@ -390,7 +390,7 @@
>
@@ -400,7 +400,7 @@
>
@@ -423,7 +423,7 @@
+ +@@ -164,8 +167,8 @@ interface TabConfig { interface FileConfig { path: string content: string - highlighted: string hint?: string // Optional hint message for this file + highlighted?: string } const props = defineProps++