Merge PR #146: feat: 提升流式网关稳定性与安全策略强化

This commit is contained in:
shaw
2026-01-04 22:45:01 +08:00
62 changed files with 2575 additions and 551 deletions

View File

@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
Additional security-related options are available in `config.yaml`:
- `cors.allowed_origins` for CORS allowlist
- `security.url_allowlist` for upstream/pricing/CRS host allowlists
- `security.csp` to control Content-Security-Policy headers
- `billing.circuit_breaker` to fail closed on billing errors
- `server.trusted_proxies` to enable X-Forwarded-For parsing
- `turnstile.required` to require Turnstile in release mode
```bash
# 6. Run the application
./sub2api

View File

@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
`config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单
- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
- `security.csp` 配置 Content-Security-Policy
- `billing.circuit_breaker` 计费异常时 fail-closed
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile
```bash
# 6. 运行应用
./sub2api

View File

@@ -86,7 +86,8 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(middleware.Recovery())
r.Use(middleware.CORS())
r.Use(middleware.CORS(config.CORSConfig{}))
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
// Register setup routes
setup.RegisterRoutes(r)

View File

@@ -76,7 +76,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
@@ -101,10 +101,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
@@ -125,7 +125,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingRemoteClient := repository.NewPricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
@@ -136,10 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)

View File

@@ -2,7 +2,10 @@
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"strings"
"time"
@@ -14,6 +17,8 @@ const (
RunModeSimple = "simple"
)
const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
@@ -30,6 +35,10 @@ const (
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
@@ -37,6 +46,7 @@ type Config struct {
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
@@ -95,11 +105,61 @@ type PricingConfig struct {
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
}
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowCredentials bool `mapstructure:"allow_credentials"`
}
type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
type URLAllowlistConfig struct {
UpstreamHosts []string `mapstructure:"upstream_hosts"`
PricingHosts []string `mapstructure:"pricing_hosts"`
CRSHosts []string `mapstructure:"crs_hosts"`
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
}
type ResponseHeaderConfig struct {
AdditionalAllowed []string `mapstructure:"additional_allowed"`
ForceRemove []string `mapstructure:"force_remove"`
}
type CSPConfig struct {
Enabled bool `mapstructure:"enabled"`
Policy string `mapstructure:"policy"`
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
}
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
}
type CircuitBreakerConfig struct {
Enabled bool `mapstructure:"enabled"`
FailureThreshold int `mapstructure:"failure_threshold"`
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
HalfOpenRequests int `mapstructure:"half_open_requests"`
}
type ConcurrencyConfig struct {
// PingInterval: 并发等待期间的 SSE ping 间隔(秒)
PingInterval int `mapstructure:"ping_interval"`
}
// GatewayConfig API网关相关配置
@@ -134,6 +194,13 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数0使用默认值
MaxLineSize int `mapstructure:"max_line_size"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
@@ -237,6 +304,10 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
@@ -287,11 +358,39 @@ func Load() (*Config, error) {
}
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug"
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
secret, err := generateJWTSecret(64)
if err != nil {
return nil, fmt.Errorf("generate jwt secret error: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
)
}
return &cfg, nil
}
@@ -304,6 +403,39 @@ func setDefaults() {
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
// Security
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
"api.openai.com",
"api.anthropic.com",
"generativelanguage.googleapis.com",
"cloudcode-pa.googleapis.com",
"*.openai.azure.com",
})
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
"raw.githubusercontent.com",
})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
// Turnstile
viper.SetDefault("turnstile.required", false)
// Database
viper.SetDefault("database.host", "localhost")
@@ -329,7 +461,7 @@ func setDefaults() {
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// Default
@@ -357,7 +489,7 @@ func setDefaults() {
viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
@@ -365,19 +497,23 @@ func setDefaults() {
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -396,11 +532,39 @@ func setDefaults() {
}
func (c *Config) Validate() error {
if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required")
if c.Server.Mode == "release" {
if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required in release mode")
}
if len(c.JWT.Secret) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 characters")
}
if isWeakJWTSecret(c.JWT.Secret) {
return fmt.Errorf("jwt.secret is too weak")
}
}
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive")
}
if c.JWT.ExpireHour > 168 {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
}
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
}
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
}
if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
}
if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
}
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
@@ -458,6 +622,9 @@ func (c *Config) Validate() error {
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
@@ -467,6 +634,26 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.StreamDataIntervalTimeout < 0 {
return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative")
}
if c.Gateway.StreamDataIntervalTimeout != 0 &&
(c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) {
return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds")
}
if c.Gateway.StreamKeepaliveInterval < 0 {
return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative")
}
if c.Gateway.StreamKeepaliveInterval != 0 &&
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
@@ -482,9 +669,57 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
return nil
}
func normalizeStringSlice(values []string) []string {
if len(values) == 0 {
return values
}
normalized := make([]string, 0, len(values))
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}
func isWeakJWTSecret(secret string) bool {
lower := strings.ToLower(strings.TrimSpace(secret))
if lower == "" {
return true
}
weak := map[string]struct{}{
"change-me-in-production": {},
"changeme": {},
"secret": {},
"password": {},
"123456": {},
"12345678": {},
"admin": {},
"jwt-secret": {},
}
_, exists := weak[lower]
return exists
}
func generateJWTSecret(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 32
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.

View File

@@ -1,8 +1,12 @@
package admin
import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -34,31 +38,31 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
})
}
@@ -111,6 +115,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
@@ -185,6 +195,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
@@ -193,34 +205,134 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
})
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`

View File

@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`

View File

@@ -11,8 +11,10 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -38,14 +40,19 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
}
}
@@ -121,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -128,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -220,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -344,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -674,7 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
status, code, message := billingErrorDetails(err)
h.errorResponse(c, status, code, message)
return
}
@@ -800,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
return http.StatusForbidden, "billing_error", msg
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,8 +27,8 @@ import (
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// defaultPingInterval 流式响应等待时发送 ping 的默认间隔
defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
@@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
}
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}

View File

@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
}
@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult
@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
continue
}
for _, v := range vv {

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
}
}
@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)

View File

@@ -25,13 +25,14 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -40,6 +41,9 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP防止 DNS Rebinding
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
@@ -79,8 +83,12 @@ func buildClient(opts Options) (*http.Client, error) {
return nil, err
}
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
}
return &http.Client{
Transport: transport,
Transport: rt,
Timeout: opts.Timeout,
}, nil
}
@@ -126,13 +134,32 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
type validatedTransport struct {
base http.RoundTripper
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
}
}
}
return t.base.RoundTrip(req)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
@@ -245,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -15,7 +15,8 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
usageURL string
allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
@@ -24,8 +25,10 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -45,7 +45,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
@@ -64,7 +67,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -78,7 +84,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
@@ -91,7 +100,10 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

View File

@@ -14,18 +14,23 @@ import (
)
type githubReleaseClient struct {
httpClient *http.Client
httpClient *http.Client
allowPrivateHosts bool
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: sharedClient,
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
}
}
@@ -64,7 +69,9 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}

View File

@@ -37,6 +37,13 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(newReq)
}
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
@@ -55,9 +62,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -82,9 +87,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
@@ -108,9 +111,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
@@ -127,9 +128,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
@@ -145,9 +144,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
@@ -159,9 +156,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
@@ -172,9 +167,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -185,9 +178,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
@@ -200,9 +191,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
@@ -211,9 +200,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
@@ -247,6 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -266,6 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -283,6 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -298,6 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
@@ -312,9 +303,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
@@ -30,9 +31,9 @@ const (
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultIdleConnTimeout: 默认空闲连接超时时间(90秒
// 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间5分钟
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
@@ -120,6 +121,10 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - 调用方必须关闭 resp.Body否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
@@ -145,6 +150,37 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
@@ -232,6 +268,9 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
return nil, fmt.Errorf("build transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,

View File

@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -16,9 +17,15 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}

View File

@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}

View File

@@ -5,28 +5,48 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)

View File

@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
allowPrivateHosts: true,
}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {

View File

@@ -22,7 +22,8 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
Timeout: 10 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}

View File

@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
"smtp_password": "secret",
"smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key",
"turnstile_secret_key_configured": true,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",

View File

@@ -2,6 +2,7 @@
package server
import (
"log"
"net/http"
"time"
@@ -36,6 +37,15 @@ func ProvideRouter(
r := gin.New()
r.Use(middleware2.Recovery())
if len(cfg.Server.TrustedProxies) > 0 {
if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
log.Printf("Failed to set trusted proxies: %v", err)
}
} else {
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}

View File

@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
return
}
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key")
}
// 如果header中没有尝试从query参数中提取Google API key风格
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return
}

View File

@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return
}
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
}
return ""
}
func allowGoogleQueryKey(path string) bool {
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
}
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{

View File

@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -1,24 +1,103 @@
package middleware
import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
var corsWarningOnce sync.Once
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
allowAll := false
for _, origin := range allowedOrigins {
if origin == "*" {
allowAll = true
break
}
}
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
if wildcardWithSpecific {
allowedOrigins = []string{"*"}
}
allowCredentials := cfg.AllowCredentials
corsWarningOnce.Do(func() {
if len(allowedOrigins) == 0 {
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
}
if wildcardWithSpecific {
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
}
if allowAll && allowCredentials {
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
}
})
if allowAll && allowCredentials {
allowCredentials = false
}
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, origin := range allowedOrigins {
if origin == "" || origin == "*" {
continue
}
allowedSet[origin] = struct{}{}
}
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
origin := strings.TrimSpace(c.GetHeader("Origin"))
originAllowed := allowAll
if origin != "" && !allowAll {
_, originAllowed = allowedSet[origin]
}
if originAllowed {
if allowAll {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Add("Vary", "Origin")
}
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
if c.Request.Method == http.MethodOptions {
if originAllowed {
c.AbortWithStatus(http.StatusNoContent)
} else {
c.AbortWithStatus(http.StatusForbidden)
}
return
}
c.Next()
}
}
func normalizeOrigins(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
}
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}

View File

@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
r.Use(middleware2.CORS())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -14,9 +15,11 @@ import (
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -53,15 +57,32 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
baseURL := account.GetBaseURL()
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID)
strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -916,57 +917,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
select {
case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
continue
}
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
@@ -1105,7 +1194,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
reader := bufio.NewReader(resp.Body)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
@@ -1120,13 +1215,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
if len(line) > 0 {
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
select {
case ev, ok := <-events:
if !ok {
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
@@ -1141,23 +1308,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}

View File

@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证
}
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}

View File

@@ -16,7 +16,8 @@ import (
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
type billingCircuitBreakerState int
if !subscription.IsActive() {
return ErrSubscriptionInvalid
}
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
if !subscription.CheckDailyLimit(group, 0) {
return ErrDailyLimitExceeded
}
if !subscription.CheckWeeklyLimit(group, 0) {
return ErrWeeklyLimitExceeded
}
if !subscription.CheckMonthlyLimit(group, 0) {
return ErrMonthlyLimitExceeded
}
return nil
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}

View File

@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
cfg: cfg,
}
}
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
@@ -196,7 +203,9 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
Timeout: 20 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist := len(allowlist) > 0
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: allowlist,
RequireAllowlist: requireAllowlist,
AllowPrivate: allowPrivate,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials

View File

@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
@@ -1225,7 +1229,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
}
// OAuth账号应用统一指纹
@@ -1594,51 +1604,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
flusher.Flush()
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
@@ -1743,12 +1835,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
@@ -2020,7 +2107,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -2100,6 +2193,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
}
}
@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -381,16 +399,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -427,7 +449,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -453,12 +479,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -650,12 +680,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -687,7 +721,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -713,12 +751,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1652,6 +1694,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1773,11 +1817,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
fullURL := strings.TrimRight(baseURL, "/") + path
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1816,9 +1864,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(),
Headers: filteredHeaders,
Body: body,
}, nil
}

View File

@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}

View File

@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
targetURL = baseURL + "/responses"
} else {
if baseURL == "" {
targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -775,48 +782,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
for {
select {
case ev, ok := <-events:
if !ok {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
line := ev.line
lastDataAt = time.Now()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
@@ -911,18 +1028,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// Pass through headers
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
c.Data(resp.StatusCode, "application/json", body)
return usage, nil
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {

View File

@@ -0,0 +1,90 @@
package service
import (
"bufio"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 1,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
start := time.Now()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
_ = pw.Close()
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err)
}
if !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: 64 * 1024,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
_, _ = pw.Write([]byte(payload))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
_ = pr.Close()
if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err)
}
if !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
}
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
// computeFileHash 计算文件哈希

View File

@@ -221,21 +221,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型

View File

@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPPasswordConfigured bool
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string

View File

@@ -9,6 +9,7 @@ import (
"log"
"os"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/repository"
@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
if strings.EqualFold(cfg.Server.Mode, "release") {
return fmt.Errorf("jwt secret is required in release mode")
}
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Test connections
@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
if strings.EqualFold(cfg.Server.Mode, "release") {
return fmt.Errorf("jwt secret is required in release mode")
}
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Generated JWT secret automatically")
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Generate admin password if not provided
@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err)
}
cfg.Admin.Password = password
log.Printf("Generated admin password: %s", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.")
fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection

View File

@@ -0,0 +1,100 @@
package logredact
import (
"encoding/json"
"strings"
)
// maxRedactDepth 限制递归深度以防止栈溢出
const maxRedactDepth = 32
var defaultSensitiveKeys = map[string]struct{}{
"authorization_code": {},
"code": {},
"code_verifier": {},
"access_token": {},
"refresh_token": {},
"id_token": {},
"client_secret": {},
"password": {},
}
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
if input == nil {
return map[string]any{}
}
keys := buildKeySet(extraKeys)
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
if !ok {
return map[string]any{}
}
return redacted
}
func RedactJSON(raw []byte, extraKeys ...string) string {
if len(raw) == 0 {
return ""
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return "<non-json payload redacted>"
}
keys := buildKeySet(extraKeys)
redacted := redactValueWithDepth(value, keys, 0)
encoded, err := json.Marshal(redacted)
if err != nil {
return "<redacted>"
}
return string(encoded)
}
func buildKeySet(extraKeys []string) map[string]struct{} {
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
for k := range defaultSensitiveKeys {
keys[k] = struct{}{}
}
for _, key := range extraKeys {
normalized := normalizeKey(key)
if normalized == "" {
continue
}
keys[normalized] = struct{}{}
}
return keys
}
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
if depth > maxRedactDepth {
return "<depth limit exceeded>"
}
switch v := value.(type) {
case map[string]any:
out := make(map[string]any, len(v))
for k, val := range v {
if isSensitiveKey(k, keys) {
out[k] = "***"
continue
}
out[k] = redactValueWithDepth(val, keys, depth+1)
}
return out
case []any:
out := make([]any, len(v))
for i, item := range v {
out[i] = redactValueWithDepth(item, keys, depth+1)
}
return out
default:
return value
}
}
func isSensitiveKey(key string, keys map[string]struct{}) bool {
_, ok := keys[normalizeKey(key)]
return ok
}
func normalizeKey(key string) string {
return strings.ToLower(strings.TrimSpace(key))
}

View File

@@ -0,0 +1,93 @@
package responseheaders
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed 定义允许透传的响应头白名单
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
// - connection: 由 HTTP 库管理连接复用
var defaultAllowed = map[string]struct{}{
"content-type": {},
"content-encoding": {},
"content-language": {},
"cache-control": {},
"etag": {},
"last-modified": {},
"expires": {},
"vary": {},
"date": {},
"x-request-id": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
"www-authenticate": {},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
var hopByHopHeaders = map[string]struct{}{
"content-length": {},
"transfer-encoding": {},
"connection": {},
}
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
}
for _, key := range cfg.AdditionalAllowed {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
allowed[normalized] = struct{}{}
}
forceRemove := make(map[string]struct{}, len(cfg.ForceRemove))
for _, key := range cfg.ForceRemove {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
forceRemove[normalized] = struct{}{}
}
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
if _, blocked := forceRemove[lower]; blocked {
continue
}
if _, ok := allowed[lower]; !ok {
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
continue
}
for _, value := range values {
filtered.Add(key, value)
}
}
return filtered
}
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
filtered := FilterHeaders(src, cfg)
for key, values := range filtered {
for _, value := range values {
dst.Add(key, value)
}
}
}

View File

@@ -0,0 +1,121 @@
package urlvalidator
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
)
type ValidationOptions struct {
AllowedHosts []string
RequireAllowlist bool
AllowPrivate bool
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid host")
}
if !opts.AllowPrivate && isBlockedHost(host) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
allowlist := normalizeAllowlist(opts.AllowedHosts)
if opts.RequireAllowlist && len(allowlist) == 0 {
return "", errors.New("allowlist is not configured")
}
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
parsed.Path = strings.TrimRight(parsed.Path, "/")
parsed.RawPath = ""
return strings.TrimRight(parsed.String(), "/"), nil
}
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
func ValidateResolvedIP(host string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return fmt.Errorf("dns resolution failed: %w", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
}
}
return nil
}
func normalizeAllowlist(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, v := range values {
entry := strings.ToLower(strings.TrimSpace(v))
if entry == "" {
continue
}
if host, _, err := net.SplitHostPort(entry); err == nil {
entry = host
}
normalized = append(normalized, entry)
}
return normalized
}
func isAllowedHost(host string, allowlist []string) bool {
for _, entry := range allowlist {
if entry == "" {
continue
}
if strings.HasPrefix(entry, "*.") {
suffix := strings.TrimPrefix(entry, "*.")
if host == suffix || strings.HasSuffix(host, "."+suffix) {
return true
}
continue
}
if host == entry {
return true
}
}
return false
}
func isBlockedHost(host string) bool {
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
}
return false
}

View File

@@ -25,8 +25,8 @@
timeouts {
read_body 30s
read_header 10s
write 60s
idle 120s
write 300s
idle 300s
}
}
}
@@ -77,7 +77,10 @@ example.com {
write_buffer 16KB
compression off
}
# SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端
flush_interval -1
# 故障转移
fail_duration 30s
max_fails 3
@@ -92,6 +95,10 @@ example.com {
gzip 6
minimum_length 256
match {
# SSE 请求通常会带 Accept: text/event-stream需排除压缩
not header Accept text/event-stream*
# 排除已知 SSE 路径(即便 Accept 缺失)
not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/*
header Content-Type text/*
header Content-Type application/json*
header Content-Type application/javascript*
@@ -179,6 +186,3 @@ example.com {
# =============================================================================
# HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明)
# =============================================================================
; http://example.com {
; redir https://{host}{uri} permanent
; }

View File

@@ -12,6 +12,8 @@ server:
port: 8080
# Mode: "debug" for development, "release" for production
mode: "release"
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
trusted_proxies: []
# =============================================================================
# Run Mode Configuration
@@ -21,12 +23,54 @@ server:
# - simple: Hides SaaS features and skips billing/balance checks
run_mode: "standard"
# =============================================================================
# CORS Configuration
# =============================================================================
cors:
# Allowed origins list. Leave empty to disable cross-origin requests.
allowed_origins: []
# Allow credentials (cookies/authorization headers). Cannot be used with "*".
allow_credentials: true
# =============================================================================
# Security Configuration
# =============================================================================
security:
url_allowlist:
# Allowed upstream hosts for API proxying
upstream_hosts:
- "api.openai.com"
- "api.anthropic.com"
- "generativelanguage.googleapis.com"
- "cloudcode-pa.googleapis.com"
- "*.openai.azure.com"
# Allowed hosts for pricing data download
pricing_hosts:
- "raw.githubusercontent.com"
# Allowed hosts for CRS sync (required when using CRS sync)
crs_hosts: []
# Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks)
allow_private_hosts: false
response_headers:
# Extra allowed response headers from upstream
additional_allowed: []
# Force-remove response headers from upstream
force_remove: []
csp:
# Enable Content-Security-Policy header
enabled: true
# Default CSP policy (override if you host assets on other domains)
policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
proxy_probe:
# Allow skipping TLS verification for proxy probe (debug only)
insecure_skip_verify: false
# =============================================================================
# 网关配置
# =============================================================================
gateway:
# 等待上游响应头超时时间(秒)
response_header_timeout: 300
response_header_timeout: 600
# 请求体最大字节数(默认 100MB
max_body_size: 104857600
# 连接池隔离策略:
@@ -38,14 +82,35 @@ gateway:
max_idle_conns: 240
max_idle_conns_per_host: 120
max_conns_per_host: 240
idle_conn_timeout_seconds: 300
idle_conn_timeout_seconds: 90
# 上游连接池客户端缓存配置
# max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的
# client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收
max_upstream_clients: 5000
client_idle_ttl_seconds: 900
# 并发槽位过期时间(分钟)
concurrency_slot_ttl_minutes: 15
concurrency_slot_ttl_minutes: 30
# 流数据间隔超时0=禁用
stream_data_interval_timeout: 180
# 流式 keepalive 间隔0=禁用
stream_keepalive_interval: 10
# SSE 单行最大字节数(默认 10MB
max_line_size: 10485760
# Log upstream error response body summary (safe/truncated; does not log request content)
log_upstream_error_body: false
# Max bytes to log from upstream error body
log_upstream_error_body_max_bytes: 2048
# Auto inject anthropic-beta for API-key accounts when needed (default off)
inject_beta_for_apikey: false
# Allow failover on selected 400 errors (default off)
failover_on_400: false
# =============================================================================
# 并发等待配置
# =============================================================================
concurrency:
# 并发等待期间的 SSE ping 间隔(秒)
ping_interval: 10
# =============================================================================
# Database Configuration (PostgreSQL)
@@ -77,7 +142,7 @@ jwt:
# IMPORTANT: Change this to a random string in production!
# Generate with: openssl rand -hex 32
secret: "change-this-to-a-secure-random-string"
# Token expiration time in hours
# Token expiration time in hours (max 24)
expire_hour: 24
# =============================================================================
@@ -123,19 +188,21 @@ pricing:
hash_check_interval_minutes: 10
# =============================================================================
# Gateway (Optional)
# Billing Configuration
# =============================================================================
gateway:
# Wait time (in seconds) for upstream response headers (streaming body not affected)
response_header_timeout: 300
# Log upstream error response body summary (safe/truncated; does not log request content)
log_upstream_error_body: false
# Max bytes to log from upstream error body
log_upstream_error_body_max_bytes: 2048
# Auto inject anthropic-beta for API-key accounts when needed (default off)
inject_beta_for_apikey: false
# Allow failover on selected 400 errors (default off)
failover_on_400: false
billing:
circuit_breaker:
enabled: true
failure_threshold: 5
reset_timeout_seconds: 30
half_open_requests: 3
# =============================================================================
# Turnstile Configuration
# =============================================================================
turnstile:
# Require Turnstile in release mode (when enabled, login/register will fail if not configured)
required: false
# =============================================================================
# Gemini OAuth (Required for Gemini accounts)

View File

@@ -26,14 +26,37 @@ export interface SystemSettings {
smtp_host: string
smtp_port: number
smtp_username: string
smtp_password: string
smtp_password_configured: boolean
smtp_from_email: string
smtp_from_name: string
smtp_use_tls: boolean
// Cloudflare Turnstile settings
turnstile_enabled: boolean
turnstile_site_key: string
turnstile_secret_key: string
turnstile_secret_key_configured: boolean
}
export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
default_balance?: number
default_concurrency?: number
site_name?: string
site_logo?: string
site_subtitle?: string
api_base_url?: string
contact_info?: string
doc_url?: string
smtp_host?: string
smtp_port?: number
smtp_username?: string
smtp_password?: string
smtp_from_email?: string
smtp_from_name?: string
smtp_use_tls?: boolean
turnstile_enabled?: boolean
turnstile_site_key?: string
turnstile_secret_key?: string
}
/**
@@ -50,7 +73,7 @@ export async function getSettings(): Promise<SystemSettings> {
* @param settings - Partial settings to update
* @returns Updated settings
*/
export async function updateSettings(settings: Partial<SystemSettings>): Promise<SystemSettings> {
export async function updateSettings(settings: UpdateSettingsRequest): Promise<SystemSettings> {
const { data } = await apiClient.put<SystemSettings>('/admin/settings', settings)
return data
}

View File

@@ -62,8 +62,24 @@ apiClient.interceptors.response.use(
// 401: Unauthorized - clear token and redirect to login
if (status === 401) {
const hasToken = !!localStorage.getItem('auth_token')
const url = error.config?.url || ''
const isAuthEndpoint =
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
const headers = error.config?.headers as Record<string, unknown> | undefined
const authHeader = headers?.Authorization ?? headers?.authorization
const sentAuth =
typeof authHeader === 'string'
? authHeader.trim() !== ''
: Array.isArray(authHeader)
? authHeader.length > 0
: !!authHeader
localStorage.removeItem('auth_token')
localStorage.removeItem('auth_user')
if ((hasToken || sentAuth) && !isAuthEndpoint) {
sessionStorage.setItem('auth_expired', '1')
}
// Only redirect if not already on login page
if (!window.location.pathname.includes('/login')) {
window.location.href = '/login'

View File

@@ -136,16 +136,16 @@
<ol
class="list-inside list-decimal space-y-1 text-xs text-amber-700 dark:text-amber-300"
>
<li v-html="t('admin.accounts.oauth.step1')"></li>
<li v-html="t('admin.accounts.oauth.step2')"></li>
<li v-html="t('admin.accounts.oauth.step3')"></li>
<li v-html="t('admin.accounts.oauth.step4')"></li>
<li v-html="t('admin.accounts.oauth.step5')"></li>
<li v-html="t('admin.accounts.oauth.step6')"></li>
<li>{{ t('admin.accounts.oauth.step1') }}</li>
<li>{{ t('admin.accounts.oauth.step2') }}</li>
<li>{{ t('admin.accounts.oauth.step3') }}</li>
<li>{{ t('admin.accounts.oauth.step4') }}</li>
<li>{{ t('admin.accounts.oauth.step5') }}</li>
<li>{{ t('admin.accounts.oauth.step6') }}</li>
</ol>
<p
class="mt-2 text-xs text-amber-600 dark:text-amber-400"
v-html="t('admin.accounts.oauth.sessionKeyFormat')"
v-text="t('admin.accounts.oauth.sessionKeyFormat')"
></p>
</div>
@@ -390,7 +390,7 @@
>
<p
class="text-xs text-amber-800 dark:text-amber-300"
v-html="oauthImportantNotice"
v-text="oauthImportantNotice"
></p>
</div>
<!-- Proxy Warning (for non-OpenAI) -->
@@ -400,7 +400,7 @@
>
<p
class="text-xs text-yellow-800 dark:text-yellow-300"
v-html="t('admin.accounts.oauth.proxyWarning')"
v-text="t('admin.accounts.oauth.proxyWarning')"
></p>
</div>
</div>
@@ -423,7 +423,7 @@
</p>
<p
class="mb-3 text-sm text-blue-700 dark:text-blue-300"
v-html="oauthAuthCodeDesc"
v-text="oauthAuthCodeDesc"
></p>
<div>
<label class="input-label">

View File

@@ -107,7 +107,10 @@
</button>
</div>
<!-- Code Content -->
<pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto"><code v-html="file.highlighted"></code></pre>
<pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto">
<code v-if="file.highlighted" v-html="file.highlighted"></code>
<code v-else v-text="file.content"></code>
</pre>
</div>
</div>
</div>
@@ -164,8 +167,8 @@ interface TabConfig {
interface FileConfig {
path: string
content: string
highlighted: string
hint?: string // Optional hint message for this file
highlighted?: string
}
const props = defineProps<Props>()
@@ -311,14 +314,23 @@ const platformNote = computed(() => {
}
})
// Syntax highlighting helpers
const keyword = (text: string) => `<span class="text-purple-400">${text}</span>`
const variable = (text: string) => `<span class="text-cyan-400">${text}</span>`
const string = (text: string) => `<span class="text-green-400">${text}</span>`
const operator = (text: string) => `<span class="text-yellow-400">${text}</span>`
const comment = (text: string) => `<span class="text-gray-500">${text}</span>`
const key = (text: string) => `<span class="text-blue-400">${text}</span>`
const escapeHtml = (value: string) => value
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;')
const wrapToken = (className: string, value: string) =>
`<span class="${className}">${escapeHtml(value)}</span>`
const keyword = (value: string) => wrapToken('text-emerald-300', value)
const variable = (value: string) => wrapToken('text-sky-200', value)
const operator = (value: string) => wrapToken('text-slate-400', value)
const string = (value: string) => wrapToken('text-amber-200', value)
const comment = (value: string) => wrapToken('text-slate-500', value)
// Syntax highlighting helpers
// Generate file configs based on platform and active tab
const currentFiles = computed((): FileConfig[] => {
const baseUrl = props.baseUrl || window.location.origin
@@ -343,37 +355,29 @@ const currentFiles = computed((): FileConfig[] => {
function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] {
let path: string
let content: string
let highlighted: string
switch (activeTab.value) {
case 'unix':
path = 'Terminal'
content = `export ANTHROPIC_BASE_URL="${baseUrl}"
export ANTHROPIC_AUTH_TOKEN="${apiKey}"`
highlighted = `${keyword('export')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
${keyword('export')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break
case 'cmd':
path = 'Command Prompt'
content = `set ANTHROPIC_BASE_URL=${baseUrl}
set ANTHROPIC_AUTH_TOKEN=${apiKey}`
highlighted = `${keyword('set')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${baseUrl}
${keyword('set')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${apiKey}`
break
case 'powershell':
path = 'PowerShell'
content = `$env:ANTHROPIC_BASE_URL="${baseUrl}"
$env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
highlighted = `${keyword('$env:')}${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
${keyword('$env:')}${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break
default:
path = 'Terminal'
content = ''
highlighted = ''
}
return [{ path, content, highlighted }]
return [{ path, content }]
}
function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig {
@@ -398,9 +402,9 @@ ${keyword('export')} ${variable('GEMINI_MODEL')}${operator('=')}${string(`"${mod
content = `set GOOGLE_GEMINI_BASE_URL=${baseUrl}
set GEMINI_API_KEY=${apiKey}
set GEMINI_MODEL=${model}`
highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${baseUrl}
${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${apiKey}
${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${model}
highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${string(baseUrl)}
${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${string(apiKey)}
${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${string(model)}
${comment(`REM ${modelComment}`)}`
break
case 'powershell':
@@ -440,40 +444,20 @@ base_url = "${baseUrl}"
wire_api = "responses"
requires_openai_auth = true`
const configHighlighted = `${key('model_provider')} ${operator('=')} ${string('"sub2api"')}
${key('model')} ${operator('=')} ${string('"gpt-5.2-codex"')}
${key('model_reasoning_effort')} ${operator('=')} ${string('"high"')}
${key('network_access')} ${operator('=')} ${string('"enabled"')}
${key('disable_response_storage')} ${operator('=')} ${keyword('true')}
${key('windows_wsl_setup_acknowledged')} ${operator('=')} ${keyword('true')}
${key('model_verbosity')} ${operator('=')} ${string('"high"')}
${comment('[model_providers.sub2api]')}
${key('name')} ${operator('=')} ${string('"sub2api"')}
${key('base_url')} ${operator('=')} ${string(`"${baseUrl}"`)}
${key('wire_api')} ${operator('=')} ${string('"responses"')}
${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
// auth.json content
const authContent = `{
"OPENAI_API_KEY": "${apiKey}"
}`
const authHighlighted = `{
${key('"OPENAI_API_KEY"')}: ${string(`"${apiKey}"`)}
}`
return [
{
path: `${configDir}/config.toml`,
content: configContent,
highlighted: configHighlighted,
hint: t('keys.useKeyModal.openai.configTomlHint')
},
{
path: `${configDir}/auth.json`,
content: authContent,
highlighted: authHighlighted
content: authContent
}
]
}

View File

@@ -63,6 +63,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { getPublicSettings } from '@/api/auth'
import { sanitizeUrl } from '@/utils/url'
const siteName = ref('Sub2API')
const siteLogo = ref('')
@@ -74,7 +75,7 @@ onMounted(async () => {
try {
const settings = await getPublicSettings()
siteName.value = settings.site_name || 'Sub2API'
siteLogo.value = settings.site_logo || ''
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
siteSubtitle.value = settings.site_subtitle || 'Subscription to API Conversion Platform'
} catch (error) {
console.error('Failed to load public settings:', error)

View File

@@ -200,6 +200,7 @@ export default {
registrationFailed: 'Registration failed. Please try again.',
loginSuccess: 'Login successful! Welcome back.',
accountCreatedSuccess: 'Account created successfully! Welcome to {siteName}.',
reloginRequired: 'Session expired. Please log in again.',
turnstileExpired: 'Verification expired, please try again',
turnstileFailed: 'Verification failed, please try again',
completeVerification: 'Please complete the verification',
@@ -1127,13 +1128,13 @@ export default {
'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
howToGetSessionKey: 'How to get sessionKey',
step1: 'Login to <strong>claude.ai</strong> in your browser',
step2: 'Press <kbd>F12</kbd> to open Developer Tools',
step3: 'Go to <strong>Application</strong> tab',
step4: 'Find <strong>Cookies</strong> → <strong>https://claude.ai</strong>',
step5: 'Find the row with key <strong>sessionKey</strong>',
step6: 'Copy the <strong>Value</strong>',
sessionKeyFormat: 'sessionKey usually starts with <code>sk-ant-sid01-</code>',
step1: 'Login to claude.ai in your browser',
step2: 'Press F12 to open Developer Tools',
step3: 'Go to Application tab',
step4: 'Find Cookies → https://claude.ai',
step5: 'Find the row with key sessionKey',
step6: 'Copy the Value',
sessionKeyFormat: 'sessionKey usually starts with sk-ant-sid01-',
startAutoAuth: 'Start Auto-Auth',
authorizing: 'Authorizing...',
followSteps: 'Follow these steps to authorize your Claude account:',
@@ -1145,10 +1146,10 @@ export default {
openUrlDesc:
'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
proxyWarning:
'<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
'Note: If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
step3EnterCode: 'Enter the Authorization Code',
authCodeDesc:
'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
'After authorization is complete, the page will display an Authorization Code. Copy and paste it below:',
authCode: 'Authorization Code',
authCodePlaceholder: 'Paste the Authorization Code from Claude page...',
authCodeHint: 'Paste the Authorization Code copied from the Claude page',
@@ -1169,10 +1170,10 @@ export default {
openUrlDesc:
'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
importantNotice:
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
'Important: The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to http://localhost..., the authorization is complete.',
step3EnterCode: 'Enter Authorization URL or Code',
authCodeDesc:
'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
'After authorization is complete, when the page URL becomes http://localhost:xxx/auth/callback?code=...:',
authCode: 'Authorization URL or Code',
authCodePlaceholder:
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
@@ -1195,7 +1196,7 @@ export default {
'Open the authorization URL in a new tab, log in to your Google account and authorize.',
step3EnterCode: 'Enter Authorization URL or Code',
authCodeDesc:
'After authorization, copy the callback URL (recommended) or just the <code>code</code> and paste it below.',
'After authorization, copy the callback URL (recommended) or just the code and paste it below.',
authCode: 'Callback URL or Code',
authCodePlaceholder:
'Option 1 (recommended): Paste the callback URL\nOption 2: Paste only the code value',
@@ -1237,10 +1238,10 @@ export default {
step2OpenUrl: 'Open the URL in your browser and complete authorization',
openUrlDesc: 'Open the authorization URL in a new tab, log in to your Google account and authorize.',
importantNotice:
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar shows <code>http://localhost...</code>, authorization is complete.',
'Important: The page may take a while to load after authorization. Please wait patiently. When the browser address bar shows http://localhost..., authorization is complete.',
step3EnterCode: 'Enter Authorization URL or Code',
authCodeDesc:
'After authorization, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
'After authorization, when the page URL becomes http://localhost:xxx/auth/callback?code=...:',
authCode: 'Authorization URL or Code',
authCodePlaceholder:
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
@@ -1652,8 +1653,8 @@ export default {
secretKey: 'Secret Key',
siteKeyHint: 'Get this from your Cloudflare Dashboard',
cloudflareDashboard: 'Cloudflare Dashboard',
secretKeyHint: 'Server-side verification key (keep this secret)'
},
secretKeyHint: 'Server-side verification key (keep this secret)',
secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' },
defaults: {
title: 'Default User Settings',
description: 'Default values for new users',
@@ -1703,6 +1704,8 @@ export default {
password: 'SMTP Password',
passwordPlaceholder: '********',
passwordHint: 'Leave empty to keep existing password',
passwordConfiguredPlaceholder: '********',
passwordConfiguredHint: 'Password configured. Leave empty to keep the current value.',
fromEmail: 'From Email',
fromEmailPlaceholder: "noreply{'@'}example.com",
fromName: 'From Name',

View File

@@ -197,6 +197,7 @@ export default {
registrationFailed: '注册失败,请重试。',
loginSuccess: '登录成功!欢迎回来。',
accountCreatedSuccess: '账户创建成功!欢迎使用 {siteName}。',
reloginRequired: '会话已过期,请重新登录。',
turnstileExpired: '验证已过期,请重试',
turnstileFailed: '验证失败,请重试',
completeVerification: '请完成验证',
@@ -1279,13 +1280,13 @@ export default {
'每行一个 sessionKey例如\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
howToGetSessionKey: '如何获取 sessionKey',
step1: '在浏览器中登录 <strong>claude.ai</strong>',
step2: '按 <kbd>F12</kbd> 打开开发者工具',
step3: '切换到 <strong>Application</strong> 标签',
step4: '找到 <strong>Cookies</strong> → <strong>https://claude.ai</strong>',
step5: '找到 <strong>sessionKey</strong> 所在行',
step6: '复制 <strong>Value</strong> 列的值',
sessionKeyFormat: 'sessionKey 通常以 <code>sk-ant-sid01-</code> 开头',
step1: '在浏览器中登录 claude.ai',
step2: '按 F12 打开开发者工具',
step3: '切换到 Application 标签',
step4: '找到 Cookies → https://claude.ai',
step5: '找到 sessionKey 所在行',
step6: '复制 Value 列的值',
sessionKeyFormat: 'sessionKey 通常以 sk-ant-sid01- 开头',
startAutoAuth: '开始自动授权',
authorizing: '授权中...',
followSteps: '按照以下步骤授权您的 Claude 账号:',
@@ -1296,9 +1297,9 @@ export default {
step2OpenUrl: '在浏览器中打开 URL 并完成授权',
openUrlDesc: '在新标签页中打开授权 URL登录您的 Claude 账号并授权。',
proxyWarning:
'<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
'注意:如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
step3EnterCode: '输入授权码',
authCodeDesc: '授权完成后,页面会显示一个 <strong>授权码</strong>。复制并粘贴到下方:',
authCodeDesc: '授权完成后,页面会显示一个授权码。复制并粘贴到下方:',
authCode: '授权码',
authCodePlaceholder: '粘贴 Claude 页面的授权码...',
authCodeHint: '粘贴从 Claude 页面复制的授权码',
@@ -1318,10 +1319,10 @@ export default {
step2OpenUrl: '在浏览器中打开链接并完成授权',
openUrlDesc: '请在新标签页中打开授权链接,登录您的 OpenAI 账户并授权。',
importantNotice:
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
'重要提示:授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 http://localhost... 开头时,表示授权已完成。',
step3EnterCode: '输入授权链接或 Code',
authCodeDesc:
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
'授权完成后,当页面地址变为 http://localhost:xxx/auth/callback?code=... 时:',
authCode: '授权链接或 Code',
authCodePlaceholder:
'方式1复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2仅复制 code 参数的值',
@@ -1340,7 +1341,7 @@ export default {
step2OpenUrl: '在浏览器中打开链接并完成授权',
openUrlDesc: '请在新标签页中打开授权链接,登录您的 Google 账户并授权。',
step3EnterCode: '输入回调链接或 Code',
authCodeDesc: '授权完成后,复制浏览器跳转后的回调链接(推荐)或仅复制 <code>code</code>,粘贴到下方即可。',
authCodeDesc: '授权完成后,复制浏览器跳转后的回调链接(推荐)或仅复制 code粘贴到下方即可。',
authCode: '回调链接或 Code',
authCodePlaceholder: '方式1推荐粘贴回调链接\n方式2仅粘贴 code 参数的值',
authCodeHint: '系统会自动从链接中解析 code/state。',
@@ -1376,10 +1377,10 @@ export default {
step2OpenUrl: '在浏览器中打开链接并完成授权',
openUrlDesc: '请在新标签页中打开授权链接,登录您的 Google 账户并授权。',
importantNotice:
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
'重要提示:授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 http://localhost... 开头时,表示授权已完成。',
step3EnterCode: '输入授权链接或 Code',
authCodeDesc:
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
'授权完成后,当页面地址变为 http://localhost:xxx/auth/callback?code=... 时:',
authCode: '授权链接或 Code',
authCodePlaceholder:
'方式1复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2仅复制 code 参数的值',
@@ -1847,8 +1848,8 @@ export default {
secretKey: '私密密钥',
siteKeyHint: '从 Cloudflare Dashboard 获取',
cloudflareDashboard: 'Cloudflare Dashboard',
secretKeyHint: '服务端验证密钥(请保密)'
},
secretKeyHint: '服务端验证密钥(请保密)',
secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。' },
defaults: {
title: '用户默认设置',
description: '新用户的默认值',
@@ -1897,6 +1898,8 @@ export default {
password: 'SMTP 密码',
passwordPlaceholder: '********',
passwordHint: '留空以保留现有密码',
passwordConfiguredPlaceholder: '********',
passwordConfiguredHint: '密码已配置,留空以保留当前值。',
fromEmail: '发件人邮箱',
fromEmailPlaceholder: "noreply{'@'}example.com",
fromName: '发件人名称',

37
frontend/src/utils/url.ts Normal file
View File

@@ -0,0 +1,37 @@
/**
* 验证并规范化 URL
* 默认只接受绝对 URL以 http:// 或 https:// 开头),可按需允许相对路径
* @param value 用户输入的 URL
* @returns 规范化后的 URL如果无效则返回空字符串
*/
type SanitizeOptions = {
allowRelative?: boolean
}
export function sanitizeUrl(value: string, options: SanitizeOptions = {}): string {
const trimmed = value.trim()
if (!trimmed) {
return ''
}
if (options.allowRelative && trimmed.startsWith('/') && !trimmed.startsWith('//')) {
return trimmed
}
// 只接受绝对 URL不使用 base URL 来避免相对路径被解析为当前域名
// 检查是否以 http:// 或 https:// 开头
if (!trimmed.match(/^https?:\/\//i)) {
return ''
}
try {
const parsed = new URL(trimmed)
const protocol = parsed.protocol.toLowerCase()
if (protocol !== 'http:' && protocol !== 'https:') {
return ''
}
return parsed.toString()
} catch {
return ''
}
}

View File

@@ -493,6 +493,7 @@ import { useI18n } from 'vue-i18n'
import { getPublicSettings } from '@/api/auth'
import { useAuthStore } from '@/stores'
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
import { sanitizeUrl } from '@/utils/url'
const { t } = useI18n()
@@ -549,9 +550,9 @@ onMounted(async () => {
try {
const settings = await getPublicSettings()
siteName.value = settings.site_name || 'Sub2API'
siteLogo.value = settings.site_logo || ''
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
siteSubtitle.value = settings.site_subtitle || 'AI API Gateway Platform'
docUrl.value = settings.doc_url || ''
docUrl.value = sanitizeUrl(settings.doc_url || '', { allowRelative: true })
} catch (error) {
console.error('Failed to load public settings:', error)
}

View File

@@ -255,7 +255,11 @@
placeholder="0x4AAAAAAA..."
/>
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.turnstile.secretKeyHint') }}
{{
form.turnstile_secret_key_configured
? t('admin.settings.turnstile.secretKeyConfiguredHint')
: t('admin.settings.turnstile.secretKeyHint')
}}
</p>
</div>
</div>
@@ -577,10 +581,18 @@
v-model="form.smtp_password"
type="password"
class="input"
:placeholder="t('admin.settings.smtp.passwordPlaceholder')"
:placeholder="
form.smtp_password_configured
? t('admin.settings.smtp.passwordConfiguredPlaceholder')
: t('admin.settings.smtp.passwordPlaceholder')
"
/>
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.smtp.passwordHint') }}
{{
form.smtp_password_configured
? t('admin.settings.smtp.passwordConfiguredHint')
: t('admin.settings.smtp.passwordHint')
}}
</p>
</div>
<div>
@@ -713,7 +725,7 @@
import { ref, reactive, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api'
import type { SystemSettings } from '@/api/admin/settings'
import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings'
import AppLayout from '@/components/layout/AppLayout.vue'
import Toggle from '@/components/common/Toggle.vue'
import { useAppStore } from '@/stores'
@@ -735,7 +747,12 @@ const adminApiKeyMasked = ref('')
const adminApiKeyOperating = ref(false)
const newAdminApiKey = ref('')
const form = reactive<SystemSettings>({
type SettingsForm = SystemSettings & {
smtp_password: string
turnstile_secret_key: string
}
const form = reactive<SettingsForm>({
registration_enabled: true,
email_verify_enabled: false,
default_balance: 0,
@@ -750,13 +767,15 @@ const form = reactive<SystemSettings>({
smtp_port: 587,
smtp_username: '',
smtp_password: '',
smtp_password_configured: false,
smtp_from_email: '',
smtp_from_name: '',
smtp_use_tls: true,
// Cloudflare Turnstile
turnstile_enabled: false,
turnstile_site_key: '',
turnstile_secret_key: ''
turnstile_secret_key: '',
turnstile_secret_key_configured: false
})
function handleLogoUpload(event: Event) {
@@ -802,6 +821,8 @@ async function loadSettings() {
try {
const settings = await adminAPI.settings.getSettings()
Object.assign(form, settings)
form.smtp_password = ''
form.turnstile_secret_key = ''
} catch (error: any) {
appStore.showError(
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
@@ -814,7 +835,32 @@ async function loadSettings() {
async function saveSettings() {
saving.value = true
try {
await adminAPI.settings.updateSettings(form)
const payload: UpdateSettingsRequest = {
registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled,
default_balance: form.default_balance,
default_concurrency: form.default_concurrency,
site_name: form.site_name,
site_logo: form.site_logo,
site_subtitle: form.site_subtitle,
api_base_url: form.api_base_url,
contact_info: form.contact_info,
doc_url: form.doc_url,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,
smtp_password: form.smtp_password || undefined,
smtp_from_email: form.smtp_from_email,
smtp_from_name: form.smtp_from_name,
smtp_use_tls: form.smtp_use_tls,
turnstile_enabled: form.turnstile_enabled,
turnstile_site_key: form.turnstile_site_key,
turnstile_secret_key: form.turnstile_secret_key || undefined
}
const updated = await adminAPI.settings.updateSettings(payload)
Object.assign(form, updated)
form.smtp_password = ''
form.turnstile_secret_key = ''
// Refresh cached public settings so sidebar/header update immediately
await appStore.fetchPublicSettings(true)
appStore.showSuccess(t('admin.settings.settingsSaved'))

View File

@@ -277,6 +277,14 @@ const errors = reactive({
// ==================== Lifecycle ====================
onMounted(async () => {
const expiredFlag = sessionStorage.getItem('auth_expired')
if (expiredFlag) {
sessionStorage.removeItem('auth_expired')
const message = t('auth.reloginRequired')
errorMessage.value = message
appStore.showWarning(message)
}
try {
const settings = await getPublicSettings()
turnstileEnabled.value = settings.turnstile_enabled