diff --git a/README.md b/README.md index c3a37e68..95c67986 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,15 @@ default: rate_multiplier: 1.0 ``` +Additional security-related options are available in `config.yaml`: + +- `cors.allowed_origins` for CORS allowlist +- `security.url_allowlist` for upstream/pricing/CRS host allowlists +- `security.csp` to control Content-Security-Policy headers +- `billing.circuit_breaker` to fail closed on billing errors +- `server.trusted_proxies` to enable X-Forwarded-For parsing +- `turnstile.required` to require Turnstile in release mode + ```bash # 6. Run the application ./sub2api diff --git a/README_CN.md b/README_CN.md index 8fe46c51..59436998 100644 --- a/README_CN.md +++ b/README_CN.md @@ -268,6 +268,15 @@ default: rate_multiplier: 1.0 ``` +`config.yaml` 还支持以下安全相关配置: + +- `cors.allowed_origins` 配置 CORS 白名单 +- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单 +- `security.csp` 配置 Content-Security-Policy +- `billing.circuit_breaker` 计费异常时 fail-closed +- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For +- `turnstile.required` 在 release 模式强制启用 Turnstile + ```bash # 6. 运行应用 ./sub2api diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index d9e0c485..c9dc57bb 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -86,7 +86,8 @@ func main() { func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) - r.Use(middleware.CORS()) + r.Use(middleware.CORS(config.CORSConfig{})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) // Register setup routes setup.RegisterRoutes(r) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 47cf5f3a..9f23c993 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService) accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) - proxyExitInfoProber := repository.NewProxyExitInfoProber() + proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) @@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) - crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) + concurrencyService := service.NewConcurrencyService(concurrencyCache) + crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) @@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - pricingRemoteClient := repository.NewPricingRemoteClient() + pricingRemoteClient := repository.NewPricingRemoteClient(configConfig) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) if err != nil { return nil, err @@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) - geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) + geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 916efd87..1d8c64ef 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -2,7 +2,10 @@ package config import ( + "crypto/rand" + "encoding/hex" "fmt" + "log" "strings" "time" @@ -14,6 +17,8 @@ const ( RunModeSimple = "simple" ) +const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" + // 连接池隔离策略常量 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( @@ -30,6 +35,10 @@ const ( type Config struct { Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` Database DatabaseConfig `mapstructure:"database"` Redis RedisConfig `mapstructure:"redis"` JWT JWTConfig `mapstructure:"jwt"` @@ -37,6 +46,7 @@ type Config struct { RateLimit RateLimitConfig `mapstructure:"rate_limit"` Pricing PricingConfig `mapstructure:"pricing"` Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" @@ -95,11 +105,61 @@ type PricingConfig struct { } type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Mode string `mapstructure:"mode"` // debug/release - ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) - IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` // debug/release + ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) + TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) +} + +type CORSConfig struct { + AllowedOrigins []string `mapstructure:"allowed_origins"` + AllowCredentials bool `mapstructure:"allow_credentials"` +} + +type SecurityConfig struct { + URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` + ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` + CSP CSPConfig `mapstructure:"csp"` + ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` +} + +type URLAllowlistConfig struct { + UpstreamHosts []string `mapstructure:"upstream_hosts"` + PricingHosts []string `mapstructure:"pricing_hosts"` + CRSHosts []string `mapstructure:"crs_hosts"` + AllowPrivateHosts bool `mapstructure:"allow_private_hosts"` +} + +type ResponseHeaderConfig struct { + AdditionalAllowed []string `mapstructure:"additional_allowed"` + ForceRemove []string `mapstructure:"force_remove"` +} + +type CSPConfig struct { + Enabled bool `mapstructure:"enabled"` + Policy string `mapstructure:"policy"` +} + +type ProxyProbeConfig struct { + InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` +} + +type BillingConfig struct { + CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` +} + +type CircuitBreakerConfig struct { + Enabled bool `mapstructure:"enabled"` + FailureThreshold int `mapstructure:"failure_threshold"` + ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"` + HalfOpenRequests int `mapstructure:"half_open_requests"` +} + +type ConcurrencyConfig struct { + // PingInterval: 并发等待期间的 SSE ping 间隔(秒) + PingInterval int `mapstructure:"ping_interval"` } // GatewayConfig API网关相关配置 @@ -134,6 +194,13 @@ type GatewayConfig struct { // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用 + StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` + // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 + StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` + // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) + MaxLineSize int `mapstructure:"max_line_size"` + // 是否记录上游错误响应体摘要(避免输出请求内容) LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` // 上游错误响应体记录最大字节数(超过会截断) @@ -237,6 +304,10 @@ type JWTConfig struct { ExpireHour int `mapstructure:"expire_hour"` } +type TurnstileConfig struct { + Required bool `mapstructure:"required"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -287,11 +358,39 @@ func Load() (*Config, error) { } cfg.RunMode = NormalizeRunMode(cfg.RunMode) + cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode)) + if cfg.Server.Mode == "" { + cfg.Server.Mode = "debug" + } + cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) + cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) + cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) + cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) + + if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" { + secret, err := generateJWTSecret(64) + if err != nil { + return nil, fmt.Errorf("generate jwt secret error: %w", err) + } + cfg.JWT.Secret = secret + log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.") + } if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } + if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { + log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.") + } + if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { + log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v", + cfg.Security.ResponseHeaders.AdditionalAllowed, + cfg.Security.ResponseHeaders.ForceRemove, + ) + } + return &cfg, nil } @@ -304,6 +403,39 @@ func setDefaults() { viper.SetDefault("server.mode", "debug") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 + viper.SetDefault("server.trusted_proxies", []string{}) + + // CORS + viper.SetDefault("cors.allowed_origins", []string{}) + viper.SetDefault("cors.allow_credentials", true) + + // Security + viper.SetDefault("security.url_allowlist.upstream_hosts", []string{ + "api.openai.com", + "api.anthropic.com", + "generativelanguage.googleapis.com", + "cloudcode-pa.googleapis.com", + "*.openai.azure.com", + }) + viper.SetDefault("security.url_allowlist.pricing_hosts", []string{ + "raw.githubusercontent.com", + }) + viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) + viper.SetDefault("security.url_allowlist.allow_private_hosts", false) + viper.SetDefault("security.response_headers.additional_allowed", []string{}) + viper.SetDefault("security.response_headers.force_remove", []string{}) + viper.SetDefault("security.csp.enabled", true) + viper.SetDefault("security.csp.policy", DefaultCSPPolicy) + viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + + // Billing + viper.SetDefault("billing.circuit_breaker.enabled", true) + viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) + viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30) + viper.SetDefault("billing.circuit_breaker.half_open_requests", 3) + + // Turnstile + viper.SetDefault("turnstile.required", false) // Database viper.SetDefault("database.host", "localhost") @@ -329,7 +461,7 @@ func setDefaults() { viper.SetDefault("redis.min_idle_conns", 10) // JWT - viper.SetDefault("jwt.secret", "change-me-in-production") + viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) // Default @@ -357,7 +489,7 @@ func setDefaults() { viper.SetDefault("timezone", "Asia/Shanghai") // Gateway - viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.inject_beta_for_apikey", false) @@ -365,19 +497,23 @@ func setDefaults() { viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) - viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) - viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) + viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) - viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.stream_data_interval_timeout", 180) + viper.SetDefault("gateway.stream_keepalive_interval", 10) + viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) + viper.SetDefault("concurrency.ping_interval", 10) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -396,11 +532,39 @@ func setDefaults() { } func (c *Config) Validate() error { - if c.JWT.Secret == "" { - return fmt.Errorf("jwt.secret is required") + if c.Server.Mode == "release" { + if c.JWT.Secret == "" { + return fmt.Errorf("jwt.secret is required in release mode") + } + if len(c.JWT.Secret) < 32 { + return fmt.Errorf("jwt.secret must be at least 32 characters") + } + if isWeakJWTSecret(c.JWT.Secret) { + return fmt.Errorf("jwt.secret is too weak") + } } - if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { - return fmt.Errorf("jwt.secret must be changed in production") + if c.JWT.ExpireHour <= 0 { + return fmt.Errorf("jwt.expire_hour must be positive") + } + if c.JWT.ExpireHour > 168 { + return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") + } + if c.JWT.ExpireHour > 24 { + log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) + } + if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { + return fmt.Errorf("security.csp.policy is required when CSP is enabled") + } + if c.Billing.CircuitBreaker.Enabled { + if c.Billing.CircuitBreaker.FailureThreshold <= 0 { + return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") + } + if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 { + return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive") + } + if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 { + return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive") + } } if c.Database.MaxOpenConns <= 0 { return fmt.Errorf("database.max_open_conns must be positive") @@ -458,6 +622,9 @@ func (c *Config) Validate() error { if c.Gateway.IdleConnTimeoutSeconds <= 0 { return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } + if c.Gateway.IdleConnTimeoutSeconds > 180 { + log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) + } if c.Gateway.MaxUpstreamClients <= 0 { return fmt.Errorf("gateway.max_upstream_clients must be positive") } @@ -467,6 +634,26 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } + if c.Gateway.StreamDataIntervalTimeout < 0 { + return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative") + } + if c.Gateway.StreamDataIntervalTimeout != 0 && + (c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) { + return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds") + } + if c.Gateway.StreamKeepaliveInterval < 0 { + return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative") + } + if c.Gateway.StreamKeepaliveInterval != 0 && + (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { + return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") + } + if c.Gateway.MaxLineSize < 0 { + return fmt.Errorf("gateway.max_line_size must be non-negative") + } + if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { + return fmt.Errorf("gateway.max_line_size must be at least 1MB") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } @@ -482,9 +669,57 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.SlotCleanupInterval < 0 { return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") } + if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 { + return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds") + } return nil } +func normalizeStringSlice(values []string) []string { + if len(values) == 0 { + return values + } + normalized := make([]string, 0, len(values)) + for _, v := range values { + trimmed := strings.TrimSpace(v) + if trimmed == "" { + continue + } + normalized = append(normalized, trimmed) + } + return normalized +} + +func isWeakJWTSecret(secret string) bool { + lower := strings.ToLower(strings.TrimSpace(secret)) + if lower == "" { + return true + } + weak := map[string]struct{}{ + "change-me-in-production": {}, + "changeme": {}, + "secret": {}, + "password": {}, + "123456": {}, + "12345678": {}, + "admin": {}, + "jwt-secret": {}, + } + _, exists := weak[lower] + return exists +} + +func generateJWTSecret(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 32 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + // GetServerAddress returns the server address (host:port) from config file or environment variable. // This is a lightweight function that can be used before full config validation, // such as during setup wizard startup. diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index a52b06b4..9ce29785 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,8 +1,12 @@ package admin import ( + "log" + "time" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -34,33 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPassword: settings.SMTPPassword, - SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKey: settings.TurnstileSecretKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, + FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, }) } @@ -117,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } + previousSettings, err := h.settingService.GetAllSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + // 验证参数 if req.DefaultConcurrency < 1 { req.DefaultConcurrency = 1 @@ -193,6 +203,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } + h.auditSettingsUpdate(c, previousSettings, settings, req) + // 重新获取设置返回 updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) if err != nil { @@ -201,36 +213,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPassword: updatedSettings.SMTPPassword, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKey: updatedSettings.TurnstileSecretKey, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, }) } +func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) { + if before == nil || after == nil { + return + } + + changed := diffSettings(before, after, req) + if len(changed) == 0 { + return + } + + subject, _ := middleware.GetAuthSubjectFromContext(c) + role, _ := middleware.GetUserRoleFromContext(c) + log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v", + time.Now().UTC().Format(time.RFC3339), + subject.UserID, + role, + changed, + ) +} + +func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string { + changed := make([]string, 0, 20) + if before.RegistrationEnabled != after.RegistrationEnabled { + changed = append(changed, "registration_enabled") + } + if before.EmailVerifyEnabled != after.EmailVerifyEnabled { + changed = append(changed, "email_verify_enabled") + } + if before.SMTPHost != after.SMTPHost { + changed = append(changed, "smtp_host") + } + if before.SMTPPort != after.SMTPPort { + changed = append(changed, "smtp_port") + } + if before.SMTPUsername != after.SMTPUsername { + changed = append(changed, "smtp_username") + } + if req.SMTPPassword != "" { + changed = append(changed, "smtp_password") + } + if before.SMTPFrom != after.SMTPFrom { + changed = append(changed, "smtp_from_email") + } + if before.SMTPFromName != after.SMTPFromName { + changed = append(changed, "smtp_from_name") + } + if before.SMTPUseTLS != after.SMTPUseTLS { + changed = append(changed, "smtp_use_tls") + } + if before.TurnstileEnabled != after.TurnstileEnabled { + changed = append(changed, "turnstile_enabled") + } + if before.TurnstileSiteKey != after.TurnstileSiteKey { + changed = append(changed, "turnstile_site_key") + } + if req.TurnstileSecretKey != "" { + changed = append(changed, "turnstile_secret_key") + } + if before.SiteName != after.SiteName { + changed = append(changed, "site_name") + } + if before.SiteLogo != after.SiteLogo { + changed = append(changed, "site_logo") + } + if before.SiteSubtitle != after.SiteSubtitle { + changed = append(changed, "site_subtitle") + } + if before.APIBaseURL != after.APIBaseURL { + changed = append(changed, "api_base_url") + } + if before.ContactInfo != after.ContactInfo { + changed = append(changed, "contact_info") + } + if before.DocURL != after.DocURL { + changed = append(changed, "doc_url") + } + if before.DefaultConcurrency != after.DefaultConcurrency { + changed = append(changed, "default_concurrency") + } + if before.DefaultBalance != after.DefaultBalance { + changed = append(changed, "default_balance") + } + if before.EnableModelFallback != after.EnableModelFallback { + changed = append(changed, "enable_model_fallback") + } + if before.FallbackModelAnthropic != after.FallbackModelAnthropic { + changed = append(changed, "fallback_model_anthropic") + } + if before.FallbackModelOpenAI != after.FallbackModelOpenAI { + changed = append(changed, "fallback_model_openai") + } + if before.FallbackModelGemini != after.FallbackModelGemini { + changed = append(changed, "fallback_model_gemini") + } + if before.FallbackModelAntigravity != after.FallbackModelAntigravity { + changed = append(changed, "fallback_model_antigravity") + } + return changed +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 668fb2dc..4c50cedf 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -5,17 +5,17 @@ type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` - SMTPHost string `json:"smtp_host"` - SMTPPort int `json:"smtp_port"` - SMTPUsername string `json:"smtp_username"` - SMTPPassword string `json:"smtp_password,omitempty"` - SMTPFrom string `json:"smtp_from_email"` - SMTPFromName string `json:"smtp_from_name"` - SMTPUseTLS bool `json:"smtp_use_tls"` + SMTPHost string `json:"smtp_host"` + SMTPPort int `json:"smtp_port"` + SMTPUsername string `json:"smtp_username"` + SMTPPasswordConfigured bool `json:"smtp_password_configured"` + SMTPFrom string `json:"smtp_from_email"` + SMTPFromName string `json:"smtp_from_name"` + SMTPUseTLS bool `json:"smtp_use_tls"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8247a0c3..de3cbad9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1,7 +1,6 @@ package handler import ( - "bytes" "context" "encoding/json" "errors" @@ -12,8 +11,10 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -21,10 +22,6 @@ import ( "github.com/gin-gonic/gin" ) -const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB - -var errEmptyRequestBody = errors.New("request body is empty") - // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -35,23 +32,6 @@ type GatewayHandler struct { concurrencyHelper *ConcurrencyHelper } -func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) { - // 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。 - // 使用独立 Background context,避免客户端取消请求导致计费中断。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err) - } -} - // NewGatewayHandler creates a new GatewayHandler func NewGatewayHandler( gatewayService *service.GatewayService, @@ -60,89 +40,22 @@ func NewGatewayHandler( userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + cfg *config.Config, ) *GatewayHandler { + pingInterval := time.Duration(0) + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + } return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), } } -func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) { - if r == nil { - return nil, errEmptyRequestBody - } - - var raw bytes.Buffer - limited := io.LimitReader(r, limit+1) - tee := io.TeeReader(limited, &raw) - decoder := json.NewDecoder(tee) - - var req map[string]any - if err := decoder.Decode(&req); err != nil { - if errors.Is(err, io.EOF) { - return nil, errEmptyRequestBody - } - if int64(raw.Len()) > limit { - return nil, &http.MaxBytesError{Limit: limit} - } - return nil, err - } - - // Ensure the body contains exactly one JSON value (allowing trailing whitespace). - var extra any - if err := decoder.Decode(&extra); err != io.EOF { - if int64(raw.Len()) > limit { - return nil, &http.MaxBytesError{Limit: limit} - } - if err == nil { - return nil, fmt.Errorf("request body must contain a single JSON object") - } - return nil, err - } - if int64(raw.Len()) > limit { - return nil, &http.MaxBytesError{Limit: limit} - } - - parsed := &service.ParsedRequest{ - Body: raw.Bytes(), - } - - if rawModel, exists := req["model"]; exists { - model, ok := rawModel.(string) - if !ok { - return nil, fmt.Errorf("invalid model field type") - } - parsed.Model = model - } - if rawStream, exists := req["stream"]; exists { - stream, ok := rawStream.(bool) - if !ok { - return nil, fmt.Errorf("invalid stream field type") - } - parsed.Stream = stream - } - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - parsed.MetadataUserID = userID - } - } - // system 字段只要存在就视为显式提供(即使为 null), - // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { - parsed.HasSystem = true - parsed.System = system - } - if messages, ok := req["messages"].([]any); ok { - parsed.Messages = messages - } - - return parsed, nil -} - // Messages handles Claude API compatible messages endpoint // POST /v1/messages func (h *GatewayHandler) Messages(c *gin.Context) { @@ -159,29 +72,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes) + // 读取请求体 + body, err := io.ReadAll(c.Request.Body) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) return } - if errors.Is(err, errEmptyRequestBody) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return - } - var syntaxErr *json.SyntaxError - var typeErr *json.UnmarshalTypeError - if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } - if len(parsedReq.Body) == 0 { + + if len(body) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return } + + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } reqModel := parsedReq.Model reqStream := parsedReq.Stream @@ -217,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -224,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { log.Printf("Billing eligibility check failed after wait: %v", err) - h.handleStreamingAwareError(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required", streamStarted) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -252,9 +166,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { - log.Printf("Select account failed: %v", err) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) @@ -263,7 +176,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) { + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() } @@ -317,13 +230,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body) + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body) + result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -350,8 +266,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取) - h.recordUsageSync(apiKey, subscription, result, account) + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) return } } @@ -365,9 +293,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 选择支持该模型的账号 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { - log.Printf("Select account failed: %v", err) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) @@ -376,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) { + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if selection.Acquired && selection.ReleaseFunc != nil { selection.ReleaseFunc() } @@ -430,11 +357,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body) + result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) } else { result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) } @@ -463,8 +393,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 同步记录使用量,避免进程崩溃导致计费数据丢失(subscription已在函数开头获取) - h.recordUsageSync(apiKey, subscription, result, account) + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) return } } @@ -640,71 +582,32 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { switch statusCode { case 401: - return http.StatusBadGateway, "api_error", "Upstream authentication failed, please contact administrator" + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" case 403: - return http.StatusBadGateway, "api_error", "Upstream access forbidden, please contact administrator" + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" case 429: return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" case 529: return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later" case 500, 502, 503, 504: - return http.StatusBadGateway, "api_error", "Upstream service temporarily unavailable" + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" default: - return http.StatusBadGateway, "api_error", "Upstream request failed" + return http.StatusBadGateway, "upstream_error", "Upstream request failed" } } -func normalizeAnthropicErrorType(errType string) string { - switch errType { - case "invalid_request_error", - "authentication_error", - "permission_error", - "not_found_error", - "rate_limit_error", - "api_error", - "overloaded_error": - return errType - case "billing_error": - // Not an Anthropic-standard error type; map to the closest equivalent. - return "permission_error" - case "subscription_error": - // Not an Anthropic-standard error type; map to the closest equivalent. - return "permission_error" - case "upstream_error": - // Not an Anthropic-standard error type; keep clients compatible. - return "api_error" - default: - return "api_error" - } -} - -const maxPublicErrorMessageLen = 512 - -func sanitizePublicErrorMessage(message string) string { - cleaned := strings.TrimSpace(message) - cleaned = strings.ReplaceAll(cleaned, "\r", " ") - cleaned = strings.ReplaceAll(cleaned, "\n", " ") - if len(cleaned) > maxPublicErrorMessageLen { - cleaned = cleaned[:maxPublicErrorMessageLen] + "..." - } - return cleaned -} - // handleStreamingAwareError handles errors that may occur after streaming has started func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { - normalizedType := normalizeAnthropicErrorType(errType) - publicMessage := sanitizePublicErrorMessage(message) - if streamStarted { // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Anthropic streaming spec: send `event: error` with JSON `data`. + // Send error event in SSE format with proper JSON marshaling errorData := map[string]any{ "type": "error", "error": map[string]string{ - "type": normalizedType, - "message": publicMessage, + "type": errType, + "message": message, }, } jsonBytes, err := json.Marshal(errorData) @@ -712,11 +615,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e _ = c.Error(err) return } - if _, err := fmt.Fprintf(c.Writer, "event: error\n"); err != nil { - _ = c.Error(err) - return - } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", string(jsonBytes)); err != nil { + errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) + if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } flusher.Flush() @@ -725,19 +625,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e } // Normal case: return JSON response with proper status code - h.errorResponse(c, status, normalizedType, publicMessage) + h.errorResponse(c, status, errType, message) } // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { - normalizedType := normalizeAnthropicErrorType(errType) - publicMessage := sanitizePublicErrorMessage(message) - c.JSON(status, gin.H{ "type": "error", "error": gin.H{ - "type": normalizedType, - "message": publicMessage, + "type": errType, + "message": message, }, }) } @@ -759,30 +656,28 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes) + // 读取请求体 + body, err := io.ReadAll(c.Request.Body) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) return } - if errors.Is(err, errEmptyRequestBody) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return - } - var syntaxErr *json.SyntaxError - var typeErr *json.UnmarshalTypeError - if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } - if len(parsedReq.Body) == 0 { + + if len(body) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return } + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + // 验证 model 必填 if parsedReq.Model == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") @@ -795,8 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed: %v", err) - h.errorResponse(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required") + status, code, message := billingErrorDetails(err) + h.errorResponse(c, status, code, message) return } @@ -806,8 +701,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - log.Printf("Select account failed: %v", err) - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model") + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } @@ -923,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) { }, }) } + +func billingErrorDetails(err error) (status int, code, message string) { + if errors.Is(err, service.ErrBillingServiceUnavailable) { + msg := pkgerrors.Message(err) + if msg == "" { + msg = "Billing service temporarily unavailable. Please retry later." + } + return http.StatusServiceUnavailable, "billing_service_error", msg + } + msg := pkgerrors.Message(err) + if msg == "" { + msg = err.Error() + } + return http.StatusForbidden, "billing_error", msg +} diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 9d2e4a9d..2eb3ac72 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -26,8 +27,8 @@ import ( const ( // maxConcurrencyWait 等待并发槽位的最大时间 maxConcurrencyWait = 30 * time.Second - // pingInterval 流式响应等待时发送 ping 的间隔 - pingInterval = 15 * time.Second + // defaultPingInterval 流式响应等待时发送 ping 的默认间隔 + defaultPingInterval = 10 * time.Second // initialBackoff 初始退避时间 initialBackoff = 100 * time.Millisecond // backoffMultiplier 退避时间乘数(指数退避) @@ -44,6 +45,8 @@ const ( SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n" // SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec) SSEPingFormatNone SSEPingFormat = "" + // SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients + SSEPingFormatComment SSEPingFormat = ":\n\n" ) // ConcurrencyError represents a concurrency limit error with context @@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string { type ConcurrencyHelper struct { concurrencyService *service.ConcurrencyService pingFormat SSEPingFormat + pingInterval time.Duration } // NewConcurrencyHelper creates a new ConcurrencyHelper -func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper { +func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } return &ConcurrencyHelper{ concurrencyService: concurrencyService, pingFormat: pingFormat, + pingInterval: pingInterval, } } +// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. +// 用于避免客户端断开或上游超时导致的并发槽位泄漏。 +func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { + if releaseFunc == nil { + return nil + } + var once sync.Once + wrapped := func() { + once.Do(releaseFunc) + } + go func() { + <-ctx.Done() + wrapped() + }() + return wrapped +} + // IncrementWaitCount increments the wait count for a user func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait) @@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // Only create ping ticker if ping is needed var pingCh <-chan time.Time if needPing { - pingTicker := time.NewTicker(pingInterval) + pingTicker := time.NewTicker(h.pingInterval) defer pingTicker.Stop() pingCh = pingTicker.C } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2dbc7660..aa75e6c1 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { subscription, _ := middleware.GetSubscriptionFromContext(c) // For Gemini native API, do not send Claude-style ping frames. - geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) + geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) // 0) wait queue check maxWait := service.CalculateMaxWait(authSubject.Concurrency) @@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusTooManyRequests, err.Error()) return } + // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } // 2) billing eligibility check (after wait) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - googleError(c, http.StatusForbidden, err.Error()) + status, _, message := billingErrorDetails(err) + googleError(c, status, message) return } @@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 5) forward (根据平台分流) var result *service.ForwardResult @@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) { } for k, vv := range res.Headers { // Avoid overriding content-length and hop-by-hop headers. - if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") { + if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") { continue } for _, v := range vv { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 518cd10a..04d268a5 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + cfg *config.Config, ) *OpenAIGatewayHandler { + pingInterval := time.Duration(0) + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + } return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone), + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), } } @@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 2. Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { log.Printf("Billing eligibility check failed after wait: %v", err) - h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // Forward request result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 8a81c09a..7bf5cff4 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -25,13 +25,14 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) // Transport 连接池默认配置 const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 - defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 + defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) ) // Options 定义共享 HTTP 客户端的构建参数 @@ -40,6 +41,9 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证 + ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 + ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) + AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) // 可选的连接池参数(不设置则使用默认值) MaxIdleConns int // 最大空闲连接总数(默认 100) @@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) { return nil, err } + var rt http.RoundTripper = transport + if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { + rt = &validatedTransport{base: transport} + } return &http.Client{ - Transport: transport, + Transport: rt, Timeout: opts.Timeout, }, nil } @@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, + opts.ProxyStrict, + opts.ValidateResolvedIP, + opts.AllowPrivateHosts, opts.MaxIdleConns, opts.MaxIdleConnsPerHost, opts.MaxConnsPerHost, ) } + +type validatedTransport struct { + base http.RoundTripper +} + +func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req != nil && req.URL != nil { + host := strings.TrimSpace(req.URL.Hostname()) + if host != "" { + if err := urlvalidator.ValidateResolvedIP(host); err != nil { + return nil, err + } + } + } + return t.base.RoundTrip(req) +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 35e7f535..677fce52 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/imroc/req/v3" ) @@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) @@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe "code_challenge_method": "S256", } - reqBodyJSON, _ := json.Marshal(reqBody) log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) var result struct { @@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) @@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe fullCode = authCode + "#" + responseState } - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) + log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code") return fullCode, nil } @@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds } - reqBodyJSON, _ := json.Marshal(reqBody) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) + reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) var tokenResp oauth.TokenResponse @@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod return nil, fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) @@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client { return client } - -func prefix(s string, n int) string { - if n <= 0 { - return "" - } - if len(s) <= n { - return s - } - return s[:n] -} diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 424d1a9a..4c87b2de 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -15,7 +15,8 @@ import ( const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" type claudeUsageService struct { - usageURL string + usageURL string + allowPrivateHosts bool } func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { @@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 30 * time.Second, + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { client = &http.Client{Timeout: 30 * time.Second} diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index c3570076..2e10f3e5 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { }`) })) - s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + s.fetcher = &claudeUsageService{ + usageURL: s.srv.URL, + allowPrivateHosts: true, + } resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") require.NoError(s.T(), err, "FetchUsage") @@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { _, _ = io.WriteString(w, "nope") })) - s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + s.fetcher = &claudeUsageService{ + usageURL: s.srv.URL, + allowPrivateHosts: true, + } _, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.Error(s.T(), err) @@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { _, _ = io.WriteString(w, "not-json") })) - s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + s.fetcher = &claudeUsageService{ + usageURL: s.srv.URL, + allowPrivateHosts: true, + } _, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.Error(s.T(), err) @@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { <-r.Context().Done() })) - s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + s.fetcher = &claudeUsageService{ + usageURL: s.srv.URL, + allowPrivateHosts: true, + } ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 3fa4b1ff..dd53c091 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -14,18 +14,23 @@ import ( ) type githubReleaseClient struct { - httpClient *http.Client + httpClient *http.Client + allowPrivateHosts bool } func NewGitHubReleaseClient() service.GitHubReleaseClient { + allowPrivate := false sharedClient, err := httpclient.GetClient(httpclient.Options{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: allowPrivate, }) if err != nil { sharedClient = &http.Client{Timeout: 30 * time.Second} } return &githubReleaseClient{ - httpClient: sharedClient, + httpClient: sharedClient, + allowPrivateHosts: allowPrivate, } } @@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string } downloadClient, err := httpclient.GetClient(httpclient.Options{ - Timeout: 10 * time.Minute, + Timeout: 10 * time.Minute, + ValidateResolvedIP: true, + AllowPrivateHosts: c.allowPrivateHosts, }) if err != nil { downloadClient = &http.Client{Timeout: 10 * time.Minute} diff --git a/backend/internal/repository/github_release_service_test.go b/backend/internal/repository/github_release_service_test.go index ea849d46..4eebe81d 100644 --- a/backend/internal/repository/github_release_service_test.go +++ b/backend/internal/repository/github_release_service_test.go @@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { return http.DefaultTransport.RoundTrip(newReq) } +func newTestGitHubReleaseClient() *githubReleaseClient { + return &githubReleaseClient{ + httpClient: &http.Client{}, + allowPrivateHosts: true, + } +} + func (s *GitHubReleaseServiceSuite) SetupTest() { s.tempDir = s.T().TempDir() } @@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng _, _ = w.Write(bytes.Repeat([]byte("a"), 100)) })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() dest := filepath.Join(s.tempDir, "file1.bin") err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) @@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { } })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() dest := filepath.Join(s.tempDir, "file2.bin") err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) @@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { } })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() dest := filepath.Join(s.tempDir, "file3.bin") err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200) @@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { w.WriteHeader(http.StatusNotFound) })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() dest := filepath.Join(s.tempDir, "notfound.bin") err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100) @@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { _, _ = w.Write([]byte("sum")) })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) require.NoError(s.T(), err, "FetchChecksumFile") @@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { w.WriteHeader(http.StatusInternalServerError) })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() _, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) require.Error(s.T(), err, "expected error for non-200") @@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { <-r.Context().Done() })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() dest := filepath.Join(s.tempDir, "invalid.bin") err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100) @@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { _, _ = w.Write([]byte("content")) })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() // Use a path that cannot be created (directory doesn't exist) dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin") @@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { } func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() { - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() _, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url") require.Error(s.T(), err, "expected error for invalid URL") @@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { httpClient: &http.Client{ Transport: &testTransport{testServerURL: s.srv.URL}, }, + allowPrivateHosts: true, } release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") @@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { httpClient: &http.Client{ Transport: &testTransport{testServerURL: s.srv.URL}, }, + allowPrivateHosts: true, } _, err := s.client.FetchLatestRelease(context.Background(), "test/repo") @@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { httpClient: &http.Client{ Transport: &testTransport{testServerURL: s.srv.URL}, }, + allowPrivateHosts: true, } _, err := s.client.FetchLatestRelease(context.Background(), "test/repo") @@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { httpClient: &http.Client{ Transport: &testTransport{testServerURL: s.srv.URL}, }, + allowPrivateHosts: true, } ctx, cancel := context.WithCancel(context.Background()) @@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { <-r.Context().Done() })) - client, ok := NewGitHubReleaseClient().(*githubReleaseClient) - require.True(s.T(), ok, "type assertion failed") - s.client = client + s.client = newTestGitHubReleaseClient() ctx, cancel := context.WithCancel(context.Background()) cancel() diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index f0669979..21723d4a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) // 默认配置常量 @@ -30,9 +31,9 @@ const ( // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) // 达到上限后新请求会等待,而非无限创建连接 defaultMaxConnsPerHost = 240 - // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟) - // 超时后连接会被关闭,释放系统资源 - defaultIdleConnTimeout = 300 * time.Second + // defaultIdleConnTimeout: 默认空闲连接超时时间(90秒) + // 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时) + defaultIdleConnTimeout = 90 * time.Second // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) // LLM 请求可能排队较久,需要较长超时 defaultResponseHeaderTimeout = 300 * time.Second @@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { // - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 // - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + if err := s.validateRequestHost(req); err != nil { + return nil, err + } + // 获取或创建对应的客户端,并标记请求占用 entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) if err != nil { @@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i return resp, nil } +func (s *httpUpstreamService) shouldValidateResolvedIP() bool { + if s.cfg == nil { + return false + } + return !s.cfg.Security.URLAllowlist.AllowPrivateHosts +} + +func (s *httpUpstreamService) validateRequestHost(req *http.Request) error { + if !s.shouldValidateResolvedIP() { + return nil + } + if req == nil || req.URL == nil { + return errors.New("request url is nil") + } + host := strings.TrimSpace(req.URL.Hostname()) + if host == "" { + return errors.New("request host is empty") + } + if err := urlvalidator.ValidateResolvedIP(host); err != nil { + return err + } + return nil +} + +func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return s.validateRequestHost(req) +} + // acquireClient 获取或创建客户端,并标记为进行中请求 // 用于请求路径,避免在获取后被淘汰 func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { @@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a return nil, fmt.Errorf("build transport: %w", err) } client := &http.Client{Transport: transport} + if s.shouldValidateResolvedIP() { + client.CheckRedirect = s.redirectChecker + } entry := &upstreamClientEntry{ client: client, proxyKey: proxyKey, diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 241b490f..fbe44c5e 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct { // SetupTest 每个测试用例执行前的初始化 // 创建空配置,各测试用例可按需覆盖 func (s *HTTPUpstreamSuite) SetupTest() { - s.cfg = &config.Config{} + s.cfg = &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + AllowPrivateHosts: true, + }, + }, + } } // newService 创建测试用的 httpUpstreamService 实例 diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 11f82fd3..0a6d0cd9 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -16,9 +17,15 @@ type pricingRemoteClient struct { httpClient *http.Client } -func NewPricingRemoteClient() service.PricingRemoteClient { +func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { + allowPrivate := false + if cfg != nil { + allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts + } sharedClient, err := httpclient.GetClient(httpclient.Options{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: allowPrivate, }) if err != nil { sharedClient = &http.Client{Timeout: 30 * time.Second} diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 112c7eaa..6745ac58 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -19,7 +20,13 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient().(*pricingRemoteClient) + client, ok := NewPricingRemoteClient(&config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + AllowPrivateHosts: true, + }, + }, + }).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index f5f625f9..b49b4efb 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -5,28 +5,48 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) -func NewProxyExitInfoProber() service.ProxyExitInfoProber { - return &proxyProbeService{ipInfoURL: defaultIPInfoURL} +func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { + insecure := false + allowPrivate := false + if cfg != nil { + insecure = cfg.Security.ProxyProbe.InsecureSkipVerify + allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts + } + if insecure { + log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.") + } + return &proxyProbeService{ + ipInfoURL: defaultIPInfoURL, + insecureSkipVerify: insecure, + allowPrivateHosts: allowPrivate, + } } const defaultIPInfoURL = "https://ipinfo.io/json" type proxyProbeService struct { - ipInfoURL string + ipInfoURL string + insecureSkipVerify bool + allowPrivateHosts bool } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { client, err := httpclient.GetClient(httpclient.Options{ ProxyURL: proxyURL, Timeout: 15 * time.Second, - InsecureSkipVerify: true, + InsecureSkipVerify: s.insecureSkipVerify, + ProxyStrict: true, + ValidateResolvedIP: true, + AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index e7270324..fe45adbb 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct { func (s *ProxyProbeServiceSuite) SetupTest() { s.ctx = context.Background() - s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"} + s.prober = &proxyProbeService{ + ipInfoURL: "http://ipinfo.test/json", + allowPrivateHosts: true, + } } func (s *ProxyProbeServiceSuite) TearDownTest() { diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index cf6083e2..89748cd3 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -22,7 +22,8 @@ type turnstileVerifier struct { func NewTurnstileVerifier() service.TurnstileVerifier { sharedClient, err := httpclient.GetClient(httpclient.Options{ - Timeout: 10 * time.Second, + Timeout: 10 * time.Second, + ValidateResolvedIP: true, }) if err != nil { sharedClient = &http.Client{Timeout: 10 * time.Second} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 6b8af91e..8a469661 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) { "smtp_host": "smtp.example.com", "smtp_port": 587, "smtp_username": "user", - "smtp_password": "secret", + "smtp_password_configured": true, "smtp_from_email": "no-reply@example.com", "smtp_from_name": "Sub2API", "smtp_use_tls": true, "turnstile_enabled": true, "turnstile_site_key": "site-key", - "turnstile_secret_key": "secret-key", + "turnstile_secret_key_configured": true, "site_name": "Sub2API", "site_logo": "", "site_subtitle": "Subtitle", diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 0239410a..a8740ecc 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -2,6 +2,7 @@ package server import ( + "log" "net/http" "time" @@ -36,6 +37,15 @@ func ProvideRouter( r := gin.New() r.Use(middleware2.Recovery()) + if len(cfg.Server.TrustedProxies) > 0 { + if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil { + log.Printf("Failed to set trusted proxies: %v", err) + } + } else { + if err := r.SetTrustedProxies(nil); err != nil { + log.Printf("Failed to disable trusted proxies: %v", err) + } + } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index eb8c2aff..74ff8af3 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { + queryKey := strings.TrimSpace(c.Query("key")) + queryApiKey := strings.TrimSpace(c.Query("api_key")) + if queryKey != "" || queryApiKey != "" { + AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.") + return + } + // 尝试从Authorization header中提取API key (Bearer scheme) authHeader := c.GetHeader("Authorization") var apiKeyString string @@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti apiKeyString = c.GetHeader("x-goog-api-key") } - // 如果header中没有,尝试从query参数中提取(Google API key风格) - if apiKeyString == "" { - apiKeyString = c.Query("key") - } - - // 兼容常见别名 - if apiKeyString == "" { - apiKeyString = c.Query("api_key") - } - // 如果所有header都没有API key if apiKeyString == "" { - AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter") + AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header") return } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index f2796db4..c5afd7ef 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { + if v := strings.TrimSpace(c.Query("api_key")); v != "" { + abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.") + return + } apiKeyString := extractAPIKeyFromRequest(c) if apiKeyString == "" { abortWithGoogleError(c, 401, "API key is required") @@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string { if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" { return v } - if v := strings.TrimSpace(c.Query("key")); v != "" { - return v - } - if v := strings.TrimSpace(c.Query("api_key")); v != "" { - return v + if allowGoogleQueryKey(c.Request.URL.Path) { + if v := strings.TrimSpace(c.Query("key")); v != "" { + return v + } } return "" } +func allowGoogleQueryKey(path string) bool { + return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta") +} + func abortWithGoogleError(c *gin.Context, status int, message string) { c.JSON(status, gin.H{ "error": gin.H{ diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 96dcdbb3..0ed5a4a2 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } +func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + return nil, errors.New("should not be called") + }, + }) + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Error.Code) + require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message) + require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status) +} + +func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + return &service.APIKey{ + ID: 1, + Key: key, + Status: service.StatusActive, + User: &service.User{ + ID: 123, + Status: service.StatusActive, + }, + }, nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index bc16279f..7d82f183 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -1,24 +1,103 @@ package middleware import ( + "log" + "net/http" + "strings" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" ) +var corsWarningOnce sync.Once + // CORS 跨域中间件 -func CORS() gin.HandlerFunc { +func CORS(cfg config.CORSConfig) gin.HandlerFunc { + allowedOrigins := normalizeOrigins(cfg.AllowedOrigins) + allowAll := false + for _, origin := range allowedOrigins { + if origin == "*" { + allowAll = true + break + } + } + wildcardWithSpecific := allowAll && len(allowedOrigins) > 1 + if wildcardWithSpecific { + allowedOrigins = []string{"*"} + } + allowCredentials := cfg.AllowCredentials + + corsWarningOnce.Do(func() { + if len(allowedOrigins) == 0 { + log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.") + } + if wildcardWithSpecific { + log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.") + } + if allowAll && allowCredentials { + log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.") + } + }) + if allowAll && allowCredentials { + allowCredentials = false + } + + allowedSet := make(map[string]struct{}, len(allowedOrigins)) + for _, origin := range allowedOrigins { + if origin == "" || origin == "*" { + continue + } + allowedSet[origin] = struct{}{} + } + return func(c *gin.Context) { - // 设置允许跨域的响应头 - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + origin := strings.TrimSpace(c.GetHeader("Origin")) + originAllowed := allowAll + if origin != "" && !allowAll { + _, originAllowed = allowedSet[origin] + } + + if originAllowed { + if allowAll { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else if origin != "" { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + c.Writer.Header().Add("Vary", "Origin") + } + if allowCredentials { + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + } + } + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") // 处理预检请求 - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) + if c.Request.Method == http.MethodOptions { + if originAllowed { + c.AbortWithStatus(http.StatusNoContent) + } else { + c.AbortWithStatus(http.StatusForbidden) + } return } c.Next() } } + +func normalizeOrigins(values []string) []string { + if len(values) == 0 { + return nil + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + normalized = append(normalized, trimmed) + } + return normalized +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go new file mode 100644 index 00000000..9fca0cd3 --- /dev/null +++ b/backend/internal/server/middleware/security_headers.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +// SecurityHeaders sets baseline security headers for all responses. +func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { + policy := strings.TrimSpace(cfg.Policy) + if policy == "" { + policy = config.DefaultCSPPolicy + } + + return func(c *gin.Context) { + c.Header("X-Content-Type-Options", "nosniff") + c.Header("X-Frame-Options", "DENY") + c.Header("Referrer-Policy", "strict-origin-when-cross-origin") + if cfg.Enabled { + c.Header("Content-Security-Policy", policy) + } + c.Next() + } +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 6eebb6d8..15a1b325 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -24,7 +24,8 @@ func SetupRouter( ) *gin.Engine { // 应用中间件 r.Use(middleware2.Logger()) - r.Use(middleware2.CORS()) + r.Use(middleware2.CORS(cfg.CORS)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) // Serve embedded frontend if available if web.HasEmbeddedFrontend() { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 6573be4b..e49da48f 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "log" @@ -14,9 +15,11 @@ import ( "regexp" "strings" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -45,6 +48,7 @@ type AccountTestService struct { geminiTokenProvider *GeminiTokenProvider antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream + cfg *config.Config } // NewAccountTestService creates a new AccountTestService @@ -53,15 +57,32 @@ func NewAccountTestService( geminiTokenProvider *GeminiTokenProvider, antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, + cfg *config.Config, ) *AccountTestService { return &AccountTestService{ accountRepo: accountRepo, geminiTokenProvider: geminiTokenProvider, antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, + cfg: cfg, } } +func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) { + if s.cfg == nil { + return "", errors.New("config is not available") + } + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", err + } + return normalized, nil +} + // generateSessionString generates a Claude Code style session string func generateSessionString() (string, error) { bytes := make([]byte, 32) @@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.sendErrorAndEnd(c, "No API key available") } - apiURL = account.GetBaseURL() - if apiURL == "" { - apiURL = "https://api.anthropic.com" + baseURL := account.GetBaseURL() + if baseURL == "" { + baseURL = "https://api.anthropic.com" } - apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) + } + apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages" } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if baseURL == "" { baseURL = "https://api.openai.com" } - apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) + } + apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses" } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } // Use streamGenerateContent for real-time feedback fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", - strings.TrimRight(baseURL, "/"), modelID) + strings.TrimRight(normalizedBaseURL, "/"), modelID) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) if err != nil { @@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun if strings.TrimSpace(baseURL) == "" { baseURL = geminicli.AIStudioBaseURL } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) if err != nil { @@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT } wrappedBytes, _ := json.Marshal(wrapped) - fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, err + } + fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) if err != nil { diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 7776e4c3..7763bc40 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -11,6 +11,7 @@ import ( "log" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -1103,57 +1104,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return nil, errors.New("streaming not supported") } - reader := bufio.NewReader(resp.Body) + // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + // 上游数据间隔超时保护(防止上游挂起长期占用连接) + streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱 + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + for { - line, err := reader.ReadString('\n') - if len(line) > 0 { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return nil, ev.err + } + + line := ev.line trimmed := strings.TrimRight(line, "\r\n") if strings.HasPrefix(trimmed, "data:") { payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - _, _ = io.WriteString(c.Writer, line) - flusher.Flush() - } else { - // 解包 v1internal 响应 - inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) - if parseErr == nil && inner != nil { - payload = string(inner) - } - - // 解析 usage - var parsed map[string]any - if json.Unmarshal(inner, &parsed) == nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } - } - - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload) + if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } flusher.Flush() + continue } - } else { - _, _ = io.WriteString(c.Writer, line) - flusher.Flush() - } - } - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, err + // 解包 v1internal 响应 + inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) + if parseErr == nil && inner != nil { + payload = string(inner) + } + + // 解析 usage + var parsed map[string]any + if json.Unmarshal(inner, &parsed) == nil { + if u := extractGeminiUsage(parsed); u != nil { + usage = u + } + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + continue + } + + if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + log.Printf("Stream data interval timeout (antigravity)") + sendErrorEvent("stream_timeout") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil } func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) { @@ -1292,7 +1381,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context processor := antigravity.NewStreamingProcessor(originalModel) var firstTokenMs *int - reader := bufio.NewReader(resp.Body) + // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -1307,13 +1402,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } } - for { - line, err := reader.ReadString('\n') - if err != nil && !errors.Is(err, io.EOF) { - return nil, fmt.Errorf("stream read error: %w", err) + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) - if len(line) > 0 { + streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱 + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + + for { + select { + case ev, ok := <-events: + if !ok { + // 发送结束事件 + finalEvents, agUsage := processor.Finish() + if len(finalEvents) > 0 { + _, _ = c.Writer.Write(finalEvents) + flusher.Flush() + } + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return nil, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line // 处理 SSE 行,转换为 Claude 格式 claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) @@ -1328,23 +1495,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if len(finalEvents) > 0 { _, _ = c.Writer.Write(finalEvents) } + sendErrorEvent("write_failed") return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr } flusher.Flush() } - } - if errors.Is(err, io.EOF) { - break + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + log.Printf("Stream data interval timeout (antigravity)") + sendErrorEvent("stream_timeout") + return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - // 发送结束事件 - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() - } - - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 69765520..91551314 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { + required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required + + if required { + if s.settingService == nil { + log.Println("[Auth] Turnstile required but settings service is not configured") + return ErrTurnstileNotConfigured + } + enabled := s.settingService.IsTurnstileEnabled(ctx) + secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != "" + if !enabled || !secretConfigured { + log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) + return ErrTurnstileNotConfigured + } + } + if s.turnstileService == nil { + if required { + log.Println("[Auth] Turnstile required but service not configured") + return ErrTurnstileNotConfigured + } return nil // 服务未配置则跳过验证 } + + if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" { + log.Println("[Auth] Turnstile enabled but secret key not configured") + } + return s.turnstileService.VerifyToken(ctx, token, remoteIP) } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 86148b37..c09cafb9 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -16,7 +16,8 @@ import ( // 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 var ( - ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") + ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") + ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -72,10 +73,11 @@ type cacheWriteTask struct { // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { - cache BillingCache - userRepo UserRepository - subRepo UserSubscriptionRepository - cfg *config.Config + cache BillingCache + userRepo UserRepository + subRepo UserSubscriptionRepository + cfg *config.Config + circuitBreaker *billingCircuitBreaker cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup @@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo subRepo: subRepo, cfg: cfg, } + svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() return svc } @@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user if s.cfg.RunMode == config.RunModeSimple { return nil } + if s.circuitBreaker != nil && !s.circuitBreaker.Allow() { + return ErrBillingServiceUnavailable + } // 判断计费模式 isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil @@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { balance, err := s.GetUserBalance(ctx, userID) if err != nil { - // 缓存/数据库错误,允许通过(降级处理) - log.Printf("Warning: get user balance failed, allowing request: %v", err) - return nil + if s.circuitBreaker != nil { + s.circuitBreaker.OnFailure(err) + } + log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err) + return ErrBillingServiceUnavailable.WithCause(err) + } + if s.circuitBreaker != nil { + s.circuitBreaker.OnSuccess() } if balance <= 0 { @@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, // 获取订阅缓存数据 subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) if err != nil { - // 缓存/数据库错误,降级使用传入的subscription进行检查 - log.Printf("Warning: get subscription cache failed, using fallback: %v", err) - return s.checkSubscriptionLimitsFallback(subscription, group) + if s.circuitBreaker != nil { + s.circuitBreaker.OnFailure(err) + } + log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) + return ErrBillingServiceUnavailable.WithCause(err) + } + if s.circuitBreaker != nil { + s.circuitBreaker.OnSuccess() } // 检查订阅状态 @@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, return nil } -// checkSubscriptionLimitsFallback 降级检查订阅限额 -func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error { - if subscription == nil { - return ErrSubscriptionInvalid - } +type billingCircuitBreakerState int - if !subscription.IsActive() { - return ErrSubscriptionInvalid - } +const ( + billingCircuitClosed billingCircuitBreakerState = iota + billingCircuitOpen + billingCircuitHalfOpen +) - if !subscription.CheckDailyLimit(group, 0) { - return ErrDailyLimitExceeded - } - - if !subscription.CheckWeeklyLimit(group, 0) { - return ErrWeeklyLimitExceeded - } - - if !subscription.CheckMonthlyLimit(group, 0) { - return ErrMonthlyLimitExceeded - } - - return nil +type billingCircuitBreaker struct { + mu sync.Mutex + state billingCircuitBreakerState + failures int + openedAt time.Time + failureThreshold int + resetTimeout time.Duration + halfOpenRequests int + halfOpenRemaining int +} + +func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker { + if !cfg.Enabled { + return nil + } + resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second + if resetTimeout <= 0 { + resetTimeout = 30 * time.Second + } + halfOpen := cfg.HalfOpenRequests + if halfOpen <= 0 { + halfOpen = 1 + } + threshold := cfg.FailureThreshold + if threshold <= 0 { + threshold = 5 + } + return &billingCircuitBreaker{ + state: billingCircuitClosed, + failureThreshold: threshold, + resetTimeout: resetTimeout, + halfOpenRequests: halfOpen, + } +} + +func (b *billingCircuitBreaker) Allow() bool { + b.mu.Lock() + defer b.mu.Unlock() + + switch b.state { + case billingCircuitClosed: + return true + case billingCircuitOpen: + if time.Since(b.openedAt) < b.resetTimeout { + return false + } + b.state = billingCircuitHalfOpen + b.halfOpenRemaining = b.halfOpenRequests + log.Printf("ALERT: billing circuit breaker entering half-open state") + fallthrough + case billingCircuitHalfOpen: + if b.halfOpenRemaining <= 0 { + return false + } + b.halfOpenRemaining-- + return true + default: + return false + } +} + +func (b *billingCircuitBreaker) OnFailure(err error) { + if b == nil { + return + } + b.mu.Lock() + defer b.mu.Unlock() + + switch b.state { + case billingCircuitOpen: + return + case billingCircuitHalfOpen: + b.state = billingCircuitOpen + b.openedAt = time.Now() + b.halfOpenRemaining = 0 + log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err) + return + default: + b.failures++ + if b.failures >= b.failureThreshold { + b.state = billingCircuitOpen + b.openedAt = time.Now() + b.halfOpenRemaining = 0 + log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) + } + } +} + +func (b *billingCircuitBreaker) OnSuccess() { + if b == nil { + return + } + b.mu.Lock() + defer b.mu.Unlock() + + previousState := b.state + previousFailures := b.failures + + b.state = billingCircuitClosed + b.failures = 0 + b.halfOpenRemaining = 0 + + // 只有状态真正发生变化时才记录日志 + if previousState != billingCircuitClosed { + log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) + } else if previousFailures > 0 { + log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures) + } +} + +func circuitStateString(state billingCircuitBreakerState) string { + switch state { + case billingCircuitClosed: + return "closed" + case billingCircuitOpen: + return "open" + case billingCircuitHalfOpen: + return "half-open" + default: + return "unknown" + } } diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 1bf5a11e..759034e7 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -8,12 +8,13 @@ import ( "fmt" "io" "net/http" - "net/url" "strconv" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) type CRSSyncService struct { @@ -22,6 +23,7 @@ type CRSSyncService struct { oauthService *OAuthService openaiOAuthService *OpenAIOAuthService geminiOAuthService *GeminiOAuthService + cfg *config.Config } func NewCRSSyncService( @@ -30,6 +32,7 @@ func NewCRSSyncService( oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, + cfg *config.Config, ) *CRSSyncService { return &CRSSyncService{ accountRepo: accountRepo, @@ -37,6 +40,7 @@ func NewCRSSyncService( oauthService: oauthService, openaiOAuthService: openaiOAuthService, geminiOAuthService: geminiOAuthService, + cfg: cfg, } } @@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct { } func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { - baseURL, err := normalizeBaseURL(input.BaseURL) + if s.cfg == nil { + return nil, errors.New("config is not available") + } + baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts) if err != nil { return nil, err } @@ -196,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } client, err := httpclient.GetClient(httpclient.Options{ - Timeout: 20 * time.Second, + Timeout: 20 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { client = &http.Client{Timeout: 20 * time.Second} @@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string { return "active" } -func normalizeBaseURL(raw string) (string, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "", errors.New("base_url is required") +func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) { + // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证) + requireAllowlist := len(allowlist) > 0 + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: allowlist, + RequireAllowlist: requireAllowlist, + AllowPrivate: allowPrivate, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) } - u, err := url.Parse(trimmed) - if err != nil || u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("invalid base_url: %s", trimmed) - } - u.Path = strings.TrimRight(u.Path, "/") - return strings.TrimRight(u.String(), "/"), nil + return normalized, nil } // cleanBaseURL removes trailing suffix from base_url in credentials diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index c706fb80..5d39c01d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -15,11 +15,14 @@ import ( "regexp" "sort" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -30,6 +33,7 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL + defaultMaxLineSize = 10 * 1024 * 1024 claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." ) @@ -1342,7 +1346,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { baseURL := account.GetBaseURL() - targetURL = baseURL + "/v1/messages" + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages" + } } // OAuth账号:应用统一指纹 @@ -1711,51 +1721,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) // 设置更大的buffer以处理长行 - scanner.Buffer(make([]byte, 64*1024), 1024*1024) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判 + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } needModelReplace := originalModel != mappedModel - for scanner.Scan() { - line := scanner.Text() - if line == "event: error" { - return nil, errors.New("have error in stream") - } - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if sseDataRe.MatchString(line) { - data := sseDataRe.ReplaceAllString(line, "") - - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) + for { + select { + case ev, ok := <-events: + if !ok { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + line := ev.line + if line == "event: error" { + return nil, errors.New("have error in stream") } - // 转发行 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + // Extract data from SSE line (supports both "data: " and "data:" formats) + if sseDataRe.MatchString(line) { + data := sseDataRe.ReplaceAllString(line, "") - // 记录首字时间:第一个有效的 content_block_delta 或 message_start - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + // 如果有模型映射,替换响应中的model字段 + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + // 转发行 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + // 记录首字时间:第一个有效的 content_block_delta 或 message_start + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } else { + // 非 data 行直接转发 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() } - s.parseSSEUsage(data, usage) - } else { - // 非 data 行直接转发 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue } - flusher.Flush() + log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + sendErrorEvent("stream_timeout") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - if err := scanner.Err(); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) - } - - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } // replaceModelInSSELine 替换SSE数据行中的model字段 @@ -1860,12 +1952,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - // 透传响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) // 写入响应 c.Data(resp.StatusCode, "application/json", body) @@ -2137,7 +2224,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con targetURL := claudeAPICountTokensURL if account.Type == AccountTypeAPIKey { baseURL := account.GetBaseURL() - targetURL = baseURL + "/v1/messages/count_tokens" + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens" + } } // OAuth 账号:应用统一指纹和重写 userID @@ -2217,6 +2310,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } +func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 99e5bdf3..38050eab 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -18,9 +18,12 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" ) @@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct { rateLimitService *RateLimitService httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService + cfg *config.Config } func NewGeminiMessagesCompatService( @@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService( rateLimitService *RateLimitService, httpUpstream HTTPUpstream, antigravityGatewayService *AntigravityGatewayService, + cfg *config.Config, ) *GeminiMessagesCompatService { return &GeminiMessagesCompatService{ accountRepo: accountRepo, @@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService( rateLimitService: rateLimitService, httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, + cfg: cfg, } } @@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit return s.antigravityGatewayService } +func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + // HasAntigravityAccounts 检查是否有可用的 antigravity 账户 func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { var accounts []Account @@ -382,16 +400,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } action := "generateContent" if req.Stream { action = "streamGenerateContent" } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) if req.Stream { fullURL += "?alt=sse" } @@ -428,7 +450,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) if projectID != "" { // Mode 1: Code Assist API - fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action) if useUpstreamStream { fullURL += "?alt=sse" } @@ -454,12 +480,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action) + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) if useUpstreamStream { fullURL += "?alt=sse" } @@ -700,12 +730,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction) + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction) if useUpstreamStream { fullURL += "?alt=sse" } @@ -737,7 +771,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) if projectID != "" && !forceAIStudio { // Mode 1: Code Assist API - fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction) + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction) if useUpstreamStream { fullURL += "?alt=sse" } @@ -763,12 +801,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction) if useUpstreamStream { fullURL += "?alt=sse" } @@ -1702,6 +1744,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co _ = json.Unmarshal(respBody, &parsed) } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" @@ -1823,11 +1867,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { baseURL = geminicli.AIStudioBaseURL } - fullURL := strings.TrimRight(baseURL, "/") + path + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + fullURL := strings.TrimRight(normalizedBaseURL, "/") + path var proxyURL string if account.ProxyID != nil && account.Proxy != nil { @@ -1866,9 +1914,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + wwwAuthenticate := resp.Header.Get("Www-Authenticate") + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders) + if wwwAuthenticate != "" { + filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) + } return &UpstreamHTTPResult{ StatusCode: resp.StatusCode, - Headers: resp.Header.Clone(), + Headers: filteredHeaders, Body: body, }, nil } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 1dce9da9..48d31da9 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: strings.TrimSpace(proxyURL), - Timeout: 30 * time.Second, + ProxyURL: strings.TrimSpace(proxyURL), + Timeout: 30 * time.Second, + ValidateResolvedIP: true, }) if err != nil { client = &http.Client{Timeout: 30 * time.Second} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 02f58369..b9cf4b9e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -16,9 +16,12 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" ) @@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. case AccountTypeAPIKey: // API Key accounts use Platform API or custom base URL baseURL := account.GetOpenAIBaseURL() - if baseURL != "" { - targetURL = baseURL + "/responses" - } else { + if baseURL == "" { targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/responses" } default: targetURL = openaiPlatformAPIURL @@ -775,48 +782,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp usage := &OpenAIUsage{} var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 64*1024), 1024*1024) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,不被下游写入阻塞影响 + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + // 下游 keepalive 仅用于防止代理空闲断开 + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 + lastDataAt := time.Now() + + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } needModelReplace := originalModel != mappedModel - for scanner.Scan() { - line := scanner.Text() - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") - - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) + for { + select { + case ev, ok := <-events: + if !ok { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } - // Forward line - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + line := ev.line + lastDataAt = time.Now() - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + // Extract data from SSE line (supports both "data: " and "data:" formats) + if openaiSSEDataRe.MatchString(line) { + data := openaiSSEDataRe.ReplaceAllString(line, "") + + // Replace model in response if needed + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + // Forward line + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } else { + // Forward non-data lines as-is + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + sendErrorEvent("stream_timeout") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + if _, err := fmt.Fprint(w, ":\n\n"); err != nil { return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err } flusher.Flush() } } - if err := scanner.Err(); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) - } - - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { @@ -911,18 +1028,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - // Pass through headers - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) c.Data(resp.StatusCode, "application/json", body) return usage, nil } +func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return normalized, nil +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { var resp map[string]any if err := json.Unmarshal(body, &resp); err != nil { diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go new file mode 100644 index 00000000..bcad7ac8 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_test.go @@ -0,0 +1,90 @@ +package service + +import ( + "bufio" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +func TestOpenAIStreamingTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + start := time.Now() + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model") + _ = pw.Close() + _ = pr.Close() + + if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { + t.Fatalf("expected stream timeout error, got %v", err) + } + if !strings.Contains(rec.Body.String(), "stream_timeout") { + t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: 64 * 1024, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + // 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong + payload := "data: " + strings.Repeat("a", 128*1024) + "\n" + _, _ = pw.Write([]byte(payload)) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model") + _ = pr.Close() + + if !errors.Is(err, bufio.ErrTooLong) { + t.Fatalf("expected ErrTooLong, got %v", err) + } + if !strings.Contains(rec.Body.String(), "response_too_large") { + t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) + } +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index bb050d0a..58a24c0d 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) var ( @@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error { // downloadPricingData 从远程下载价格数据 func (s *PricingService) downloadPricingData() error { - log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL) + remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL) + if err != nil { + return err + } + log.Printf("[Pricing] Downloading from %s", remoteURL) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL) + var expectedHash string + if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { + expectedHash, err = s.fetchRemoteHash() + if err != nil { + return fmt.Errorf("fetch remote hash: %w", err) + } + } + + body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL) if err != nil { return fmt.Errorf("download failed: %w", err) } + if expectedHash != "" { + actualHash := sha256.Sum256(body) + if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { + return fmt.Errorf("pricing hash mismatch") + } + } + // 解析JSON数据(使用灵活的解析方式) data, err := s.parsePricingData(body) if err != nil { @@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error { // fetchRemoteHash 从远程获取哈希值 func (s *PricingService) fetchRemoteHash() (string, error) { + hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL) + if err != nil { + return "", err + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL) + hash, err := s.remoteClient.FetchHashText(ctx, hashURL) + if err != nil { + return "", err + } + return strings.TrimSpace(hash), nil +} + +func (s *PricingService) validatePricingURL(raw string) (string, error) { + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ + AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return "", fmt.Errorf("invalid pricing url: %w", err) + } + return normalized, nil } // computeFileHash 计算文件哈希 diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a331594e..6ce8ba2b 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -228,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // parseSettings 解析设置到结构体 func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { result := &SystemSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", - SMTPHost: settings[SettingKeySMTPHost], - SMTPUsername: settings[SettingKeySMTPUsername], - SMTPFrom: settings[SettingKeySMTPFrom], - SMTPFromName: settings[SettingKeySMTPFromName], - SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", + SMTPHost: settings[SettingKeySMTPHost], + SMTPUsername: settings[SettingKeySMTPUsername], + SMTPFrom: settings[SettingKeySMTPFrom], + SMTPFromName: settings[SettingKeySMTPFromName], + SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", + SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], } // 解析整数类型 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 1fba5e13..de0331f7 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -4,17 +4,19 @@ type SystemSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool - SMTPHost string - SMTPPort int - SMTPUsername string - SMTPPassword string - SMTPFrom string - SMTPFromName string - SMTPUseTLS bool + SMTPHost string + SMTPPort int + SMTPUsername string + SMTPPassword string + SMTPPasswordConfigured bool + SMTPFrom string + SMTPFromName string + SMTPUseTLS bool - TurnstileEnabled bool - TurnstileSiteKey string - TurnstileSecretKey string + TurnstileEnabled bool + TurnstileSiteKey string + TurnstileSecretKey string + TurnstileSecretKeyConfigured bool SiteName string SiteLogo string diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 435f6289..ad077735 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -9,6 +9,7 @@ import ( "log" "os" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error { // Generate JWT secret if not provided if cfg.JWT.Secret == "" { + if strings.EqualFold(cfg.Server.Mode, "release") { + return fmt.Errorf("jwt secret is required in release mode") + } secret, err := generateSecret(32) if err != nil { return fmt.Errorf("failed to generate jwt secret: %w", err) } cfg.JWT.Secret = secret + log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.") + } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 { + return fmt.Errorf("jwt secret must be at least 32 characters in release mode") } // Test connections @@ -474,12 +481,17 @@ func AutoSetupFromEnv() error { // Generate JWT secret if not provided if cfg.JWT.Secret == "" { + if strings.EqualFold(cfg.Server.Mode, "release") { + return fmt.Errorf("jwt secret is required in release mode") + } secret, err := generateSecret(32) if err != nil { return fmt.Errorf("failed to generate jwt secret: %w", err) } cfg.JWT.Secret = secret - log.Println("Generated JWT secret automatically") + log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.") + } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 { + return fmt.Errorf("jwt secret must be at least 32 characters in release mode") } // Generate admin password if not provided @@ -489,8 +501,8 @@ func AutoSetupFromEnv() error { return fmt.Errorf("failed to generate admin password: %w", err) } cfg.Admin.Password = password - log.Printf("Generated admin password: %s", cfg.Admin.Password) - log.Println("IMPORTANT: Save this password! It will not be shown again.") + fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password) + fmt.Println("IMPORTANT: Save this password! It will not be shown again.") } // Test database connection diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go new file mode 100644 index 00000000..b2d2429f --- /dev/null +++ b/backend/internal/util/logredact/redact.go @@ -0,0 +1,100 @@ +package logredact + +import ( + "encoding/json" + "strings" +) + +// maxRedactDepth 限制递归深度以防止栈溢出 +const maxRedactDepth = 32 + +var defaultSensitiveKeys = map[string]struct{}{ + "authorization_code": {}, + "code": {}, + "code_verifier": {}, + "access_token": {}, + "refresh_token": {}, + "id_token": {}, + "client_secret": {}, + "password": {}, +} + +func RedactMap(input map[string]any, extraKeys ...string) map[string]any { + if input == nil { + return map[string]any{} + } + keys := buildKeySet(extraKeys) + redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any) + if !ok { + return map[string]any{} + } + return redacted +} + +func RedactJSON(raw []byte, extraKeys ...string) string { + if len(raw) == 0 { + return "" + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + return "" + } + keys := buildKeySet(extraKeys) + redacted := redactValueWithDepth(value, keys, 0) + encoded, err := json.Marshal(redacted) + if err != nil { + return "" + } + return string(encoded) +} + +func buildKeySet(extraKeys []string) map[string]struct{} { + keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys)) + for k := range defaultSensitiveKeys { + keys[k] = struct{}{} + } + for _, key := range extraKeys { + normalized := normalizeKey(key) + if normalized == "" { + continue + } + keys[normalized] = struct{}{} + } + return keys +} + +func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any { + if depth > maxRedactDepth { + return "" + } + + switch v := value.(type) { + case map[string]any: + out := make(map[string]any, len(v)) + for k, val := range v { + if isSensitiveKey(k, keys) { + out[k] = "***" + continue + } + out[k] = redactValueWithDepth(val, keys, depth+1) + } + return out + case []any: + out := make([]any, len(v)) + for i, item := range v { + out[i] = redactValueWithDepth(item, keys, depth+1) + } + return out + default: + return value + } +} + +func isSensitiveKey(key string, keys map[string]struct{}) bool { + _, ok := keys[normalizeKey(key)] + return ok +} + +func normalizeKey(key string) string { + return strings.ToLower(strings.TrimSpace(key)) +} diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go new file mode 100644 index 00000000..53fc03bc --- /dev/null +++ b/backend/internal/util/responseheaders/responseheaders.go @@ -0,0 +1,93 @@ +package responseheaders + +import ( + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// defaultAllowed 定义允许透传的响应头白名单 +// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置: +// - content-length: 由 ResponseWriter 根据实际写入数据自动设置 +// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除 +// - connection: 由 HTTP 库管理连接复用 +var defaultAllowed = map[string]struct{}{ + "content-type": {}, + "content-encoding": {}, + "content-language": {}, + "cache-control": {}, + "etag": {}, + "last-modified": {}, + "expires": {}, + "vary": {}, + "date": {}, + "x-request-id": {}, + "x-ratelimit-limit-requests": {}, + "x-ratelimit-limit-tokens": {}, + "x-ratelimit-remaining-requests": {}, + "x-ratelimit-remaining-tokens": {}, + "x-ratelimit-reset-requests": {}, + "x-ratelimit-reset-tokens": {}, + "retry-after": {}, + "location": {}, + "www-authenticate": {}, +} + +// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理 +var hopByHopHeaders = map[string]struct{}{ + "content-length": {}, + "transfer-encoding": {}, + "connection": {}, +} + +func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header { + allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed)) + for key := range defaultAllowed { + allowed[key] = struct{}{} + } + for _, key := range cfg.AdditionalAllowed { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + continue + } + allowed[normalized] = struct{}{} + } + + forceRemove := make(map[string]struct{}, len(cfg.ForceRemove)) + for _, key := range cfg.ForceRemove { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + continue + } + forceRemove[normalized] = struct{}{} + } + + filtered := make(http.Header, len(src)) + for key, values := range src { + lower := strings.ToLower(key) + if _, blocked := forceRemove[lower]; blocked { + continue + } + if _, ok := allowed[lower]; !ok { + continue + } + // 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理 + if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop { + continue + } + for _, value := range values { + filtered.Add(key, value) + } + } + return filtered +} + +func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) { + filtered := FilterHeaders(src, cfg) + for key, values := range filtered { + for _, value := range values { + dst.Add(key, value) + } + } +} diff --git a/backend/internal/util/urlvalidator/validator.go b/backend/internal/util/urlvalidator/validator.go new file mode 100644 index 00000000..b8f8c72f --- /dev/null +++ b/backend/internal/util/urlvalidator/validator.go @@ -0,0 +1,121 @@ +package urlvalidator + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +type ValidationOptions struct { + AllowedHosts []string + RequireAllowlist bool + AllowPrivate bool +} + +func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "", errors.New("url is required") + } + + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return "", fmt.Errorf("invalid url: %s", trimmed) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme) + } + + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return "", errors.New("invalid host") + } + if !opts.AllowPrivate && isBlockedHost(host) { + return "", fmt.Errorf("host is not allowed: %s", host) + } + + allowlist := normalizeAllowlist(opts.AllowedHosts) + if opts.RequireAllowlist && len(allowlist) == 0 { + return "", errors.New("allowlist is not configured") + } + if len(allowlist) > 0 && !isAllowedHost(host, allowlist) { + return "", fmt.Errorf("host is not allowed: %s", host) + } + + parsed.Path = strings.TrimRight(parsed.Path, "/") + parsed.RawPath = "" + return strings.TrimRight(parsed.String(), "/"), nil +} + +// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全 +// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP +func ValidateResolvedIP(host string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return fmt.Errorf("dns resolution failed: %w", err) + } + + for _, ip := range ips { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || ip.IsUnspecified() { + return fmt.Errorf("resolved ip %s is not allowed", ip.String()) + } + } + return nil +} + +func normalizeAllowlist(values []string) []string { + if len(values) == 0 { + return nil + } + normalized := make([]string, 0, len(values)) + for _, v := range values { + entry := strings.ToLower(strings.TrimSpace(v)) + if entry == "" { + continue + } + if host, _, err := net.SplitHostPort(entry); err == nil { + entry = host + } + normalized = append(normalized, entry) + } + return normalized +} + +func isAllowedHost(host string, allowlist []string) bool { + for _, entry := range allowlist { + if entry == "" { + continue + } + if strings.HasPrefix(entry, "*.") { + suffix := strings.TrimPrefix(entry, "*.") + if host == suffix || strings.HasSuffix(host, "."+suffix) { + return true + } + continue + } + if host == entry { + return true + } + } + return false +} + +func isBlockedHost(host string) bool { + if host == "localhost" || strings.HasSuffix(host, ".localhost") { + return true + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() { + return true + } + } + return false +} diff --git a/deploy/Caddyfile b/deploy/Caddyfile index eaba462b..3aeef51a 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -25,8 +25,8 @@ timeouts { read_body 30s read_header 10s - write 60s - idle 120s + write 300s + idle 300s } } } @@ -77,7 +77,10 @@ example.com { write_buffer 16KB compression off } - + + # SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端 + flush_interval -1 + # 故障转移 fail_duration 30s max_fails 3 @@ -92,6 +95,10 @@ example.com { gzip 6 minimum_length 256 match { + # SSE 请求通常会带 Accept: text/event-stream,需排除压缩 + not header Accept text/event-stream* + # 排除已知 SSE 路径(即便 Accept 缺失) + not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/* header Content-Type text/* header Content-Type application/json* header Content-Type application/javascript* @@ -179,6 +186,3 @@ example.com { # ============================================================================= # HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明) # ============================================================================= -; http://example.com { -; redir https://{host}{uri} permanent -; } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index f07e893c..7b2c7d39 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -12,6 +12,8 @@ server: port: 8080 # Mode: "debug" for development, "release" for production mode: "release" + # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. + trusted_proxies: [] # ============================================================================= # Run Mode Configuration @@ -21,12 +23,54 @@ server: # - simple: Hides SaaS features and skips billing/balance checks run_mode: "standard" +# ============================================================================= +# CORS Configuration +# ============================================================================= +cors: + # Allowed origins list. Leave empty to disable cross-origin requests. + allowed_origins: [] + # Allow credentials (cookies/authorization headers). Cannot be used with "*". + allow_credentials: true + +# ============================================================================= +# Security Configuration +# ============================================================================= +security: + url_allowlist: + # Allowed upstream hosts for API proxying + upstream_hosts: + - "api.openai.com" + - "api.anthropic.com" + - "generativelanguage.googleapis.com" + - "cloudcode-pa.googleapis.com" + - "*.openai.azure.com" + # Allowed hosts for pricing data download + pricing_hosts: + - "raw.githubusercontent.com" + # Allowed hosts for CRS sync (required when using CRS sync) + crs_hosts: [] + # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks) + allow_private_hosts: false + response_headers: + # Extra allowed response headers from upstream + additional_allowed: [] + # Force-remove response headers from upstream + force_remove: [] + csp: + # Enable Content-Security-Policy header + enabled: true + # Default CSP policy (override if you host assets on other domains) + policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" + proxy_probe: + # Allow skipping TLS verification for proxy probe (debug only) + insecure_skip_verify: false + # ============================================================================= # 网关配置 # ============================================================================= gateway: # 等待上游响应头超时时间(秒) - response_header_timeout: 300 + response_header_timeout: 600 # 请求体最大字节数(默认 100MB) max_body_size: 104857600 # 连接池隔离策略: @@ -38,14 +82,35 @@ gateway: max_idle_conns: 240 max_idle_conns_per_host: 120 max_conns_per_host: 240 - idle_conn_timeout_seconds: 300 + idle_conn_timeout_seconds: 90 # 上游连接池客户端缓存配置 # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 max_upstream_clients: 5000 client_idle_ttl_seconds: 900 # 并发槽位过期时间(分钟) - concurrency_slot_ttl_minutes: 15 + concurrency_slot_ttl_minutes: 30 + # 流数据间隔超时(秒),0=禁用 + stream_data_interval_timeout: 180 + # 流式 keepalive 间隔(秒),0=禁用 + stream_keepalive_interval: 10 + # SSE 单行最大字节数(默认 10MB) + max_line_size: 10485760 + # Log upstream error response body summary (safe/truncated; does not log request content) + log_upstream_error_body: false + # Max bytes to log from upstream error body + log_upstream_error_body_max_bytes: 2048 + # Auto inject anthropic-beta for API-key accounts when needed (default off) + inject_beta_for_apikey: false + # Allow failover on selected 400 errors (default off) + failover_on_400: false + +# ============================================================================= +# 并发等待配置 +# ============================================================================= +concurrency: + # 并发等待期间的 SSE ping 间隔(秒) + ping_interval: 10 # ============================================================================= # Database Configuration (PostgreSQL) @@ -77,7 +142,7 @@ jwt: # IMPORTANT: Change this to a random string in production! # Generate with: openssl rand -hex 32 secret: "change-this-to-a-secure-random-string" - # Token expiration time in hours + # Token expiration time in hours (max 24) expire_hour: 24 # ============================================================================= @@ -123,19 +188,21 @@ pricing: hash_check_interval_minutes: 10 # ============================================================================= -# Gateway (Optional) +# Billing Configuration # ============================================================================= -gateway: - # Wait time (in seconds) for upstream response headers (streaming body not affected) - response_header_timeout: 300 - # Log upstream error response body summary (safe/truncated; does not log request content) - log_upstream_error_body: false - # Max bytes to log from upstream error body - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta for API-key accounts when needed (default off) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default off) - failover_on_400: false +billing: + circuit_breaker: + enabled: true + failure_threshold: 5 + reset_timeout_seconds: 30 + half_open_requests: 3 + +# ============================================================================= +# Turnstile Configuration +# ============================================================================= +turnstile: + # Require Turnstile in release mode (when enabled, login/register will fail if not configured) + required: false # ============================================================================= # Gemini OAuth (Required for Gemini accounts) diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index cc91c09b..6b46de7d 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -26,20 +26,44 @@ export interface SystemSettings { smtp_host: string smtp_port: number smtp_username: string - smtp_password: string + smtp_password_configured: boolean smtp_from_email: string smtp_from_name: string smtp_use_tls: boolean // Cloudflare Turnstile settings turnstile_enabled: boolean turnstile_site_key: string - turnstile_secret_key: string - + turnstile_secret_key_configured: boolean // Identity patch configuration (Claude -> Gemini) enable_identity_patch: boolean identity_patch_prompt: string } +export interface UpdateSettingsRequest { + registration_enabled?: boolean + email_verify_enabled?: boolean + default_balance?: number + default_concurrency?: number + site_name?: string + site_logo?: string + site_subtitle?: string + api_base_url?: string + contact_info?: string + doc_url?: string + smtp_host?: string + smtp_port?: number + smtp_username?: string + smtp_password?: string + smtp_from_email?: string + smtp_from_name?: string + smtp_use_tls?: boolean + turnstile_enabled?: boolean + turnstile_site_key?: string + turnstile_secret_key?: string + enable_identity_patch?: boolean + identity_patch_prompt?: string +} + /** * Get all system settings * @returns System settings @@ -54,7 +78,7 @@ export async function getSettings(): Promise { * @param settings - Partial settings to update * @returns Updated settings */ -export async function updateSettings(settings: Partial): Promise { +export async function updateSettings(settings: UpdateSettingsRequest): Promise { const { data } = await apiClient.put('/admin/settings', settings) return data } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 0aafd893..1cc8e55b 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -69,8 +69,24 @@ apiClient.interceptors.response.use( // 401: Unauthorized - clear token and redirect to login if (status === 401) { + const hasToken = !!localStorage.getItem('auth_token') + const url = error.config?.url || '' + const isAuthEndpoint = + url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh') + const headers = error.config?.headers as Record | undefined + const authHeader = headers?.Authorization ?? headers?.authorization + const sentAuth = + typeof authHeader === 'string' + ? authHeader.trim() !== '' + : Array.isArray(authHeader) + ? authHeader.length > 0 + : !!authHeader + localStorage.removeItem('auth_token') localStorage.removeItem('auth_user') + if ((hasToken || sentAuth) && !isAuthEndpoint) { + sessionStorage.setItem('auth_expired', '1') + } // Only redirect if not already on login page if (!window.location.pathname.includes('/login')) { window.location.href = '/login' diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 41c316f5..7ce30b46 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -136,16 +136,16 @@
    -
  1. -
  2. -
  3. -
  4. -
  5. -
  6. +
  7. {{ t('admin.accounts.oauth.step1') }}
  8. +
  9. {{ t('admin.accounts.oauth.step2') }}
  10. +
  11. {{ t('admin.accounts.oauth.step3') }}
  12. +
  13. {{ t('admin.accounts.oauth.step4') }}
  14. +
  15. {{ t('admin.accounts.oauth.step5') }}
  16. +
  17. {{ t('admin.accounts.oauth.step6') }}

@@ -390,7 +390,7 @@ >

@@ -400,7 +400,7 @@ >

@@ -423,7 +423,7 @@

-
+
+                
+                
+              
@@ -164,8 +167,8 @@ interface TabConfig { interface FileConfig { path: string content: string - highlighted: string hint?: string // Optional hint message for this file + highlighted?: string } const props = defineProps() @@ -311,14 +314,23 @@ const platformNote = computed(() => { } }) -// Syntax highlighting helpers -const keyword = (text: string) => `${text}` -const variable = (text: string) => `${text}` -const string = (text: string) => `${text}` -const operator = (text: string) => `${text}` -const comment = (text: string) => `${text}` -const key = (text: string) => `${text}` +const escapeHtml = (value: string) => value + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''') +const wrapToken = (className: string, value: string) => + `${escapeHtml(value)}` + +const keyword = (value: string) => wrapToken('text-emerald-300', value) +const variable = (value: string) => wrapToken('text-sky-200', value) +const operator = (value: string) => wrapToken('text-slate-400', value) +const string = (value: string) => wrapToken('text-amber-200', value) +const comment = (value: string) => wrapToken('text-slate-500', value) + +// Syntax highlighting helpers // Generate file configs based on platform and active tab const currentFiles = computed((): FileConfig[] => { const baseUrl = props.baseUrl || window.location.origin @@ -343,37 +355,29 @@ const currentFiles = computed((): FileConfig[] => { function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] { let path: string let content: string - let highlighted: string switch (activeTab.value) { case 'unix': path = 'Terminal' content = `export ANTHROPIC_BASE_URL="${baseUrl}" export ANTHROPIC_AUTH_TOKEN="${apiKey}"` - highlighted = `${keyword('export')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)} -${keyword('export')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}` break case 'cmd': path = 'Command Prompt' content = `set ANTHROPIC_BASE_URL=${baseUrl} set ANTHROPIC_AUTH_TOKEN=${apiKey}` - highlighted = `${keyword('set')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${baseUrl} -${keyword('set')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${apiKey}` break case 'powershell': path = 'PowerShell' content = `$env:ANTHROPIC_BASE_URL="${baseUrl}" $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"` - highlighted = `${keyword('$env:')}${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)} -${keyword('$env:')}${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}` break default: path = 'Terminal' content = '' - highlighted = '' } - return [{ path, content, highlighted }] + return [{ path, content }] } function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig { @@ -398,9 +402,9 @@ ${keyword('export')} ${variable('GEMINI_MODEL')}${operator('=')}${string(`"${mod content = `set GOOGLE_GEMINI_BASE_URL=${baseUrl} set GEMINI_API_KEY=${apiKey} set GEMINI_MODEL=${model}` - highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${baseUrl} -${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${apiKey} -${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${model} + highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${string(baseUrl)} +${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${string(apiKey)} +${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${string(model)} ${comment(`REM ${modelComment}`)}` break case 'powershell': @@ -440,40 +444,20 @@ base_url = "${baseUrl}" wire_api = "responses" requires_openai_auth = true` - const configHighlighted = `${key('model_provider')} ${operator('=')} ${string('"sub2api"')} -${key('model')} ${operator('=')} ${string('"gpt-5.2-codex"')} -${key('model_reasoning_effort')} ${operator('=')} ${string('"high"')} -${key('network_access')} ${operator('=')} ${string('"enabled"')} -${key('disable_response_storage')} ${operator('=')} ${keyword('true')} -${key('windows_wsl_setup_acknowledged')} ${operator('=')} ${keyword('true')} -${key('model_verbosity')} ${operator('=')} ${string('"high"')} - -${comment('[model_providers.sub2api]')} -${key('name')} ${operator('=')} ${string('"sub2api"')} -${key('base_url')} ${operator('=')} ${string(`"${baseUrl}"`)} -${key('wire_api')} ${operator('=')} ${string('"responses"')} -${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}` - // auth.json content const authContent = `{ "OPENAI_API_KEY": "${apiKey}" }` - const authHighlighted = `{ - ${key('"OPENAI_API_KEY"')}: ${string(`"${apiKey}"`)} -}` - return [ { path: `${configDir}/config.toml`, content: configContent, - highlighted: configHighlighted, hint: t('keys.useKeyModal.openai.configTomlHint') }, { path: `${configDir}/auth.json`, - content: authContent, - highlighted: authHighlighted + content: authContent } ] } diff --git a/frontend/src/components/layout/AuthLayout.vue b/frontend/src/components/layout/AuthLayout.vue index 1a0cfec7..3cfc1d4d 100644 --- a/frontend/src/components/layout/AuthLayout.vue +++ b/frontend/src/components/layout/AuthLayout.vue @@ -63,6 +63,7 @@