From bd4bf00856c3d95b803520a2978dd41d261b7bb1 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Fri, 2 Jan 2026 17:40:57 +0800
Subject: [PATCH 1/7] =?UTF-8?q?feat(=E5=AE=89=E5=85=A8):=20=E5=BC=BA?=
=?UTF-8?q?=E5=8C=96=E5=AE=89=E5=85=A8=E7=AD=96=E7=95=A5=E4=B8=8E=E9=85=8D?=
=?UTF-8?q?=E7=BD=AE=E6=A0=A1=E9=AA=8C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 增加 CORS/CSP/安全响应头与代理信任配置
- 引入 URL 白名单与私网开关,校验上游与价格源
- 改善 API Key 处理与网关错误返回
- 管理端设置隐藏敏感字段并优化前端提示
- 增加计费熔断与相关配置示例
测试: go test ./...
---
README.md | 9 +
README_CN.md | 9 +
backend/cmd/server/main.go | 3 +-
backend/cmd/server/wire_gen.go | 8 +-
backend/internal/config/config.go | 212 +++++++++++++++++-
.../internal/handler/admin/setting_handler.go | 105 ++++++++-
backend/internal/handler/dto/settings.go | 4 +-
backend/internal/handler/gateway_handler.go | 22 +-
.../internal/handler/gemini_v1beta_handler.go | 5 +-
.../handler/openai_gateway_handler.go | 3 +-
.../repository/claude_oauth_service.go | 13 +-
.../repository/proxy_probe_service.go | 21 +-
backend/internal/server/api_contract_test.go | 4 +-
backend/internal/server/http.go | 10 +
.../server/middleware/api_key_auth.go | 19 +-
.../server/middleware/api_key_auth_google.go | 17 +-
.../middleware/api_key_auth_google_test.go | 52 +++++
backend/internal/server/middleware/cors.go | 91 +++++++-
.../server/middleware/security_headers.go | 26 +++
backend/internal/server/router.go | 3 +-
.../internal/service/account_test_service.go | 57 ++++-
backend/internal/service/auth_service.go | 24 ++
.../internal/service/billing_cache_service.go | 159 ++++++++++++-
backend/internal/service/crs_sync_service.go | 32 ++-
backend/internal/service/gateway_service.go | 37 ++-
.../service/gemini_messages_compat_service.go | 79 +++++--
.../service/openai_gateway_service.go | 31 ++-
backend/internal/service/pricing_service.go | 47 +++-
backend/internal/service/setting_service.go | 6 +-
backend/internal/service/settings_view.go | 2 +
backend/internal/setup/setup.go | 18 +-
backend/internal/util/logredact/redact.go | 100 +++++++++
.../util/responseheaders/responseheaders.go | 92 ++++++++
.../internal/util/urlvalidator/validator.go | 121 ++++++++++
deploy/config.example.yaml | 63 +++++-
frontend/src/api/admin/settings.ts | 29 ++-
frontend/src/api/client.ts | 16 ++
.../account/OAuthAuthorizationFlow.vue | 20 +-
frontend/src/components/keys/UseKeyModal.vue | 42 +---
frontend/src/components/layout/AuthLayout.vue | 3 +-
frontend/src/i18n/locales/en.ts | 34 +--
frontend/src/i18n/locales/zh.ts | 34 +--
frontend/src/utils/url.ts | 37 +++
frontend/src/views/HomeView.vue | 5 +-
frontend/src/views/admin/SettingsView.vue | 60 ++++-
frontend/src/views/auth/LoginView.vue | 8 +
46 files changed, 1572 insertions(+), 220 deletions(-)
create mode 100644 backend/internal/server/middleware/security_headers.go
create mode 100644 backend/internal/util/logredact/redact.go
create mode 100644 backend/internal/util/responseheaders/responseheaders.go
create mode 100644 backend/internal/util/urlvalidator/validator.go
create mode 100644 frontend/src/utils/url.ts
diff --git a/README.md b/README.md
index a6d051b0..85820b68 100644
--- a/README.md
+++ b/README.md
@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
+Additional security-related options are available in `config.yaml`:
+
+- `cors.allowed_origins` for CORS allowlist
+- `security.url_allowlist` for upstream/pricing/CRS host allowlists
+- `security.csp` to control Content-Security-Policy headers
+- `billing.circuit_breaker` to fail closed on billing errors
+- `server.trusted_proxies` to enable X-Forwarded-For parsing
+- `turnstile.required` to require Turnstile in release mode
+
```bash
# 6. Run the application
./sub2api
diff --git a/README_CN.md b/README_CN.md
index 0e15be1f..f1c1ff3b 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -268,6 +268,15 @@ default:
rate_multiplier: 1.0
```
+`config.yaml` 还支持以下安全相关配置:
+
+- `cors.allowed_origins` 配置 CORS 白名单
+- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
+- `security.csp` 配置 Content-Security-Policy
+- `billing.circuit_breaker` 计费异常时 fail-closed
+- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
+- `turnstile.required` 在 release 模式强制启用 Turnstile
+
```bash
# 6. 运行应用
./sub2api
diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go
index d9e0c485..c9dc57bb 100644
--- a/backend/cmd/server/main.go
+++ b/backend/cmd/server/main.go
@@ -86,7 +86,8 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(middleware.Recovery())
- r.Use(middleware.CORS())
+ r.Use(middleware.CORS(config.CORSConfig{}))
+ r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
// Register setup routes
setup.RegisterRoutes(r)
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index d469dcbb..b95c65ce 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -77,7 +77,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
- proxyExitInfoProber := repository.NewProxyExitInfoProber()
+ proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
@@ -98,10 +98,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
- accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
+ accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
- crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
+ crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
@@ -129,7 +129,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
- geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
+ geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index aeeddcb4..c7b367d9 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -1,7 +1,10 @@
package config
import (
+ "crypto/rand"
+ "encoding/hex"
"fmt"
+ "log"
"strings"
"github.com/spf13/viper"
@@ -12,6 +15,8 @@ const (
RunModeSimple = "simple"
)
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
@@ -28,6 +33,10 @@ const (
type Config struct {
Server ServerConfig `mapstructure:"server"`
+ CORS CORSConfig `mapstructure:"cors"`
+ Security SecurityConfig `mapstructure:"security"`
+ Billing BillingConfig `mapstructure:"billing"`
+ Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
@@ -81,11 +90,56 @@ type PricingConfig struct {
}
type ServerConfig struct {
- Host string `mapstructure:"host"`
- Port int `mapstructure:"port"`
- Mode string `mapstructure:"mode"` // debug/release
- ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
- IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
+ Host string `mapstructure:"host"`
+ Port int `mapstructure:"port"`
+ Mode string `mapstructure:"mode"` // debug/release
+ ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
+ IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
+ TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
+}
+
+type CORSConfig struct {
+ AllowedOrigins []string `mapstructure:"allowed_origins"`
+ AllowCredentials bool `mapstructure:"allow_credentials"`
+}
+
+type SecurityConfig struct {
+ URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
+ ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
+ CSP CSPConfig `mapstructure:"csp"`
+ ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
+}
+
+type URLAllowlistConfig struct {
+ UpstreamHosts []string `mapstructure:"upstream_hosts"`
+ PricingHosts []string `mapstructure:"pricing_hosts"`
+ CRSHosts []string `mapstructure:"crs_hosts"`
+ AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
+}
+
+type ResponseHeaderConfig struct {
+ AdditionalAllowed []string `mapstructure:"additional_allowed"`
+ ForceRemove []string `mapstructure:"force_remove"`
+}
+
+type CSPConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ Policy string `mapstructure:"policy"`
+}
+
+type ProxyProbeConfig struct {
+ InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
+}
+
+type BillingConfig struct {
+ CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
+}
+
+type CircuitBreakerConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ FailureThreshold int `mapstructure:"failure_threshold"`
+ ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
+ HalfOpenRequests int `mapstructure:"half_open_requests"`
}
// GatewayConfig API网关相关配置
@@ -192,6 +246,10 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
+type TurnstileConfig struct {
+ Required bool `mapstructure:"required"`
+}
+
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
@@ -242,11 +300,39 @@ func Load() (*Config, error) {
}
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
+ cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
+ if cfg.Server.Mode == "" {
+ cfg.Server.Mode = "debug"
+ }
+ cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
+ cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
+ cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
+ cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
+ cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
+
+ if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
+ secret, err := generateJWTSecret(64)
+ if err != nil {
+ return nil, fmt.Errorf("generate jwt secret error: %w", err)
+ }
+ cfg.JWT.Secret = secret
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ }
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
+ if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
+ log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
+ }
+ if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
+ log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
+ cfg.Security.ResponseHeaders.AdditionalAllowed,
+ cfg.Security.ResponseHeaders.ForceRemove,
+ )
+ }
+
return &cfg, nil
}
@@ -259,6 +345,39 @@ func setDefaults() {
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
+ viper.SetDefault("server.trusted_proxies", []string{})
+
+ // CORS
+ viper.SetDefault("cors.allowed_origins", []string{})
+ viper.SetDefault("cors.allow_credentials", true)
+
+ // Security
+ viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
+ "api.openai.com",
+ "api.anthropic.com",
+ "generativelanguage.googleapis.com",
+ "cloudcode-pa.googleapis.com",
+ "*.openai.azure.com",
+ })
+ viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
+ "raw.githubusercontent.com",
+ })
+ viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
+ viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
+ viper.SetDefault("security.response_headers.additional_allowed", []string{})
+ viper.SetDefault("security.response_headers.force_remove", []string{})
+ viper.SetDefault("security.csp.enabled", true)
+ viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
+ viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
+
+ // Billing
+ viper.SetDefault("billing.circuit_breaker.enabled", true)
+ viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
+ viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
+ viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
+
+ // Turnstile
+ viper.SetDefault("turnstile.required", false)
// Database
viper.SetDefault("database.host", "localhost")
@@ -284,7 +403,7 @@ func setDefaults() {
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
- viper.SetDefault("jwt.secret", "change-me-in-production")
+ viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// Default
@@ -340,11 +459,39 @@ func setDefaults() {
}
func (c *Config) Validate() error {
- if c.JWT.Secret == "" {
- return fmt.Errorf("jwt.secret is required")
+ if c.Server.Mode == "release" {
+ if c.JWT.Secret == "" {
+ return fmt.Errorf("jwt.secret is required in release mode")
+ }
+ if len(c.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt.secret must be at least 32 characters")
+ }
+ if isWeakJWTSecret(c.JWT.Secret) {
+ return fmt.Errorf("jwt.secret is too weak")
+ }
}
- if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
- return fmt.Errorf("jwt.secret must be changed in production")
+ if c.JWT.ExpireHour <= 0 {
+ return fmt.Errorf("jwt.expire_hour must be positive")
+ }
+ if c.JWT.ExpireHour > 168 {
+ return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
+ }
+ if c.JWT.ExpireHour > 24 {
+ log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
+ }
+ if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
+ return fmt.Errorf("security.csp.policy is required when CSP is enabled")
+ }
+ if c.Billing.CircuitBreaker.Enabled {
+ if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
+ }
+ if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
+ }
+ if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
+ return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
+ }
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
@@ -414,6 +561,51 @@ func (c *Config) Validate() error {
return nil
}
+func normalizeStringSlice(values []string) []string {
+ if len(values) == 0 {
+ return values
+ }
+ normalized := make([]string, 0, len(values))
+ for _, v := range values {
+ trimmed := strings.TrimSpace(v)
+ if trimmed == "" {
+ continue
+ }
+ normalized = append(normalized, trimmed)
+ }
+ return normalized
+}
+
+func isWeakJWTSecret(secret string) bool {
+ lower := strings.ToLower(strings.TrimSpace(secret))
+ if lower == "" {
+ return true
+ }
+ weak := map[string]struct{}{
+ "change-me-in-production": {},
+ "changeme": {},
+ "secret": {},
+ "password": {},
+ "123456": {},
+ "12345678": {},
+ "admin": {},
+ "jwt-secret": {},
+ }
+ _, exists := weak[lower]
+ return exists
+}
+
+func generateJWTSecret(byteLength int) (string, error) {
+ if byteLength <= 0 {
+ byteLength = 32
+ }
+ buf := make([]byte, byteLength)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 14b569de..816db304 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1,8 +1,12 @@
package admin
import (
+ "log"
+ "time"
+
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -37,13 +41,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
- SmtpPassword: settings.SmtpPassword,
+ SmtpPasswordConfigured: settings.SmtpPasswordConfigured,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
- TurnstileSecretKey: settings.TurnstileSecretKey,
+ TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
@@ -97,6 +101,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
+ previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
@@ -136,6 +146,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
+ h.auditSettingsUpdate(c, previousSettings, settings, req)
+
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
@@ -149,13 +161,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
- SmtpPassword: updatedSettings.SmtpPassword,
+ SmtpPasswordConfigured: updatedSettings.SmtpPasswordConfigured,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
- TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
+ TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
@@ -167,6 +179,91 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
})
}
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+ if before == nil || after == nil {
+ return
+ }
+
+ changed := diffSettings(before, after, req)
+ if len(changed) == 0 {
+ return
+ }
+
+ subject, _ := middleware.GetAuthSubjectFromContext(c)
+ role, _ := middleware.GetUserRoleFromContext(c)
+ log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
+ time.Now().UTC().Format(time.RFC3339),
+ subject.UserID,
+ role,
+ changed,
+ )
+}
+
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+ changed := make([]string, 0, 16)
+ if before.RegistrationEnabled != after.RegistrationEnabled {
+ changed = append(changed, "registration_enabled")
+ }
+ if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
+ changed = append(changed, "email_verify_enabled")
+ }
+ if before.SmtpHost != after.SmtpHost {
+ changed = append(changed, "smtp_host")
+ }
+ if before.SmtpPort != after.SmtpPort {
+ changed = append(changed, "smtp_port")
+ }
+ if before.SmtpUsername != after.SmtpUsername {
+ changed = append(changed, "smtp_username")
+ }
+ if req.SmtpPassword != "" {
+ changed = append(changed, "smtp_password")
+ }
+ if before.SmtpFrom != after.SmtpFrom {
+ changed = append(changed, "smtp_from_email")
+ }
+ if before.SmtpFromName != after.SmtpFromName {
+ changed = append(changed, "smtp_from_name")
+ }
+ if before.SmtpUseTLS != after.SmtpUseTLS {
+ changed = append(changed, "smtp_use_tls")
+ }
+ if before.TurnstileEnabled != after.TurnstileEnabled {
+ changed = append(changed, "turnstile_enabled")
+ }
+ if before.TurnstileSiteKey != after.TurnstileSiteKey {
+ changed = append(changed, "turnstile_site_key")
+ }
+ if req.TurnstileSecretKey != "" {
+ changed = append(changed, "turnstile_secret_key")
+ }
+ if before.SiteName != after.SiteName {
+ changed = append(changed, "site_name")
+ }
+ if before.SiteLogo != after.SiteLogo {
+ changed = append(changed, "site_logo")
+ }
+ if before.SiteSubtitle != after.SiteSubtitle {
+ changed = append(changed, "site_subtitle")
+ }
+ if before.ApiBaseUrl != after.ApiBaseUrl {
+ changed = append(changed, "api_base_url")
+ }
+ if before.ContactInfo != after.ContactInfo {
+ changed = append(changed, "contact_info")
+ }
+ if before.DocUrl != after.DocUrl {
+ changed = append(changed, "doc_url")
+ }
+ if before.DefaultConcurrency != after.DefaultConcurrency {
+ changed = append(changed, "default_concurrency")
+ }
+ if before.DefaultBalance != after.DefaultBalance {
+ changed = append(changed, "default_balance")
+ }
+ return changed
+}
+
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 96e59e3f..752dcbee 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -8,14 +8,14 @@ type SystemSettings struct {
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
- SmtpPassword string `json:"smtp_password,omitempty"`
+ SmtpPasswordConfigured bool `json:"smtp_password_configured"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
- TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
+ TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index a2f833ff..09f2cd48 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -11,6 +11,7 @@ import (
"strings"
"time"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -127,7 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
- h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -535,7 +537,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
+ status, code, message := billingErrorDetails(err)
+ h.errorResponse(c, status, code, message)
return
}
@@ -642,3 +645,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
+
+func billingErrorDetails(err error) (status int, code, message string) {
+ if errors.Is(err, service.ErrBillingServiceUnavailable) {
+ msg := infraerrors.Message(err)
+ if msg == "" {
+ msg = "Billing service temporarily unavailable. Please retry later."
+ }
+ return http.StatusServiceUnavailable, "billing_service_error", msg
+ }
+ msg := infraerrors.Message(err)
+ if msg == "" {
+ msg = err.Error()
+ }
+ return http.StatusForbidden, "billing_error", msg
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 4e99e00d..25c7ac78 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -190,7 +190,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- googleError(c, http.StatusForbidden, err.Error())
+ status, _, message := billingErrorDetails(err)
+ googleError(c, status, message)
return
}
@@ -329,7 +330,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
- if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
+ if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
continue
}
for _, v := range vv {
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 7c9934c6..981257f2 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -131,7 +131,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
- h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index b03b5415..051741aa 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
- reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
+ reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return "", fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
- log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
+ log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
- reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
+ reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
return nil, fmt.Errorf("request failed: %w", err)
}
- log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
+ log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index 8b288c3c..1e55333c 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -8,25 +8,38 @@ import (
"net/http"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "log"
)
-func NewProxyExitInfoProber() service.ProxyExitInfoProber {
- return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
+func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
+ insecure := false
+ if cfg != nil {
+ insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
+ }
+ if insecure {
+ log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
+ }
+ return &proxyProbeService{
+ ipInfoURL: defaultIPInfoURL,
+ insecureSkipVerify: insecure,
+ }
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
- ipInfoURL string
+ ipInfoURL string
+ insecureSkipVerify bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
- InsecureSkipVerify: true,
+ InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
})
if err != nil {
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 5a243bfc..33080207 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -295,13 +295,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
- "smtp_password": "secret",
+ "smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
- "turnstile_secret_key": "secret-key",
+ "turnstile_secret_key_configured": true,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index b64220d9..49bca417 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -1,6 +1,7 @@
package server
import (
+ "log"
"net/http"
"time"
@@ -35,6 +36,15 @@ func ProvideRouter(
r := gin.New()
r.Use(middleware2.Recovery())
+ if len(cfg.Server.TrustedProxies) > 0 {
+ if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
+ log.Printf("Failed to set trusted proxies: %v", err)
+ }
+ } else {
+ if err := r.SetTrustedProxies(nil); err != nil {
+ log.Printf("Failed to disable trusted proxies: %v", err)
+ }
+ }
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}
diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go
index 75e508dd..3cfbd04a 100644
--- a/backend/internal/server/middleware/api_key_auth.go
+++ b/backend/internal/server/middleware/api_key_auth.go
@@ -19,6 +19,13 @@ func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
+ queryKey := strings.TrimSpace(c.Query("key"))
+ queryApiKey := strings.TrimSpace(c.Query("api_key"))
+ if queryKey != "" || queryApiKey != "" {
+ AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
+ return
+ }
+
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key")
}
- // 如果header中没有,尝试从query参数中提取(Google API key风格)
- if apiKeyString == "" {
- apiKeyString = c.Query("key")
- }
-
- // 兼容常见别名
- if apiKeyString == "" {
- apiKeyString = c.Query("api_key")
- }
-
// 如果所有header都没有API key
if apiKeyString == "" {
- AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
+ AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return
}
diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go
index d8f47bd2..05cddb1d 100644
--- a/backend/internal/server/middleware/api_key_auth_google.go
+++ b/backend/internal/server/middleware/api_key_auth_google.go
@@ -22,6 +22,10 @@ func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
+ if v := strings.TrimSpace(c.Query("api_key")); v != "" {
+ abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
+ return
+ }
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
- if v := strings.TrimSpace(c.Query("key")); v != "" {
- return v
- }
- if v := strings.TrimSpace(c.Query("api_key")); v != "" {
- return v
+ if allowGoogleQueryKey(c.Request.URL.Path) {
+ if v := strings.TrimSpace(c.Query("key")); v != "" {
+ return v
+ }
}
return ""
}
+func allowGoogleQueryKey(path string) bool {
+ return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
+}
+
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go
index 04d67977..51a888a7 100644
--- a/backend/internal/server/middleware/api_key_auth_google_test.go
+++ b/backend/internal/server/middleware/api_key_auth_google_test.go
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
+func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ r := gin.New()
+ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
+ return nil, errors.New("should not be called")
+ },
+ })
+ r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
+ r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ var resp googleErrorResponse
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Error.Code)
+ require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
+ require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
+}
+
+func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ r := gin.New()
+ apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
+ return &service.ApiKey{
+ ID: 1,
+ Key: key,
+ Status: service.StatusActive,
+ User: &service.User{
+ ID: 123,
+ Status: service.StatusActive,
+ },
+ }, nil
+ },
+ })
+ cfg := &config.Config{RunMode: config.RunModeSimple}
+ r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
+ r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go
index bc16279f..7d82f183 100644
--- a/backend/internal/server/middleware/cors.go
+++ b/backend/internal/server/middleware/cors.go
@@ -1,24 +1,103 @@
package middleware
import (
+ "log"
+ "net/http"
+ "strings"
+ "sync"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
+var corsWarningOnce sync.Once
+
// CORS 跨域中间件
-func CORS() gin.HandlerFunc {
+func CORS(cfg config.CORSConfig) gin.HandlerFunc {
+ allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
+ allowAll := false
+ for _, origin := range allowedOrigins {
+ if origin == "*" {
+ allowAll = true
+ break
+ }
+ }
+ wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
+ if wildcardWithSpecific {
+ allowedOrigins = []string{"*"}
+ }
+ allowCredentials := cfg.AllowCredentials
+
+ corsWarningOnce.Do(func() {
+ if len(allowedOrigins) == 0 {
+ log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
+ }
+ if wildcardWithSpecific {
+ log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
+ }
+ if allowAll && allowCredentials {
+ log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
+ }
+ })
+ if allowAll && allowCredentials {
+ allowCredentials = false
+ }
+
+ allowedSet := make(map[string]struct{}, len(allowedOrigins))
+ for _, origin := range allowedOrigins {
+ if origin == "" || origin == "*" {
+ continue
+ }
+ allowedSet[origin] = struct{}{}
+ }
+
return func(c *gin.Context) {
- // 设置允许跨域的响应头
- c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
- c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ origin := strings.TrimSpace(c.GetHeader("Origin"))
+ originAllowed := allowAll
+ if origin != "" && !allowAll {
+ _, originAllowed = allowedSet[origin]
+ }
+
+ if originAllowed {
+ if allowAll {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
+ } else if origin != "" {
+ c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
+ c.Writer.Header().Add("Vary", "Origin")
+ }
+ if allowCredentials {
+ c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+ }
+
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
- if c.Request.Method == "OPTIONS" {
- c.AbortWithStatus(204)
+ if c.Request.Method == http.MethodOptions {
+ if originAllowed {
+ c.AbortWithStatus(http.StatusNoContent)
+ } else {
+ c.AbortWithStatus(http.StatusForbidden)
+ }
return
}
c.Next()
}
}
+
+func normalizeOrigins(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ normalized := make([]string, 0, len(values))
+ for _, value := range values {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ continue
+ }
+ normalized = append(normalized, trimmed)
+ }
+ return normalized
+}
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
new file mode 100644
index 00000000..9fca0cd3
--- /dev/null
+++ b/backend/internal/server/middleware/security_headers.go
@@ -0,0 +1,26 @@
+package middleware
+
+import (
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+)
+
+// SecurityHeaders sets baseline security headers for all responses.
+func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
+ policy := strings.TrimSpace(cfg.Policy)
+ if policy == "" {
+ policy = config.DefaultCSPPolicy
+ }
+
+ return func(c *gin.Context) {
+ c.Header("X-Content-Type-Options", "nosniff")
+ c.Header("X-Frame-Options", "DENY")
+ c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
+ if cfg.Enabled {
+ c.Header("Content-Security-Policy", policy)
+ }
+ c.Next()
+ }
+}
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index 2371dafb..fd43e98a 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
- r.Use(middleware2.CORS())
+ r.Use(middleware2.CORS(cfg.CORS))
+ r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 7dd451cd..5797e497 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"io"
"log"
@@ -15,9 +16,11 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -49,6 +52,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
+ cfg *config.Config
}
// NewAccountTestService creates a new AccountTestService
@@ -59,6 +63,7 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
+ cfg *config.Config,
) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
@@ -67,9 +72,25 @@ func NewAccountTestService(
geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
+ cfg: cfg,
}
}
+func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
+ if s.cfg == nil {
+ return "", errors.New("config is not available")
+ }
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", err
+ }
+ return normalized, nil
+}
+
// generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available")
}
- apiURL = account.GetBaseURL()
- if apiURL == "" {
- apiURL = "https://api.anthropic.com"
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.anthropic.com"
}
- apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" {
baseURL = "https://api.openai.com"
}
- apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
// Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
- strings.TrimRight(baseURL, "/"), modelID)
+ strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil {
@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL
}
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil {
@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
wrappedBytes, _ := json.Marshal(wrapped)
- fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil {
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 54bbfa5c..b29aa1dc 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
+ required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
+
+ if required {
+ if s.settingService == nil {
+ log.Println("[Auth] Turnstile required but settings service is not configured")
+ return ErrTurnstileNotConfigured
+ }
+ enabled := s.settingService.IsTurnstileEnabled(ctx)
+ secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
+ if !enabled || !secretConfigured {
+ log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
+ return ErrTurnstileNotConfigured
+ }
+ }
+
if s.turnstileService == nil {
+ if required {
+ log.Println("[Auth] Turnstile required but service not configured")
+ return ErrTurnstileNotConfigured
+ }
return nil // 服务未配置则跳过验证
}
+
+ if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
+ log.Println("[Auth] Turnstile enabled but secret key not configured")
+ }
+
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index 58ed555a..b2111b8b 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -17,6 +17,7 @@ import (
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
+ ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
+ circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo,
cfg: cfg,
}
+ svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
+ if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
+ return ErrBillingServiceUnavailable
+ }
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
- // 缓存/数据库错误,允许通过(降级处理)
- log.Printf("Warning: get user balance failed, allowing request: %v", err)
- return nil
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnFailure(err)
+ }
+ log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
+ return ErrBillingServiceUnavailable.WithCause(err)
+ }
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
- // 缓存/数据库错误,降级使用传入的subscription进行检查
- log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
- return s.checkSubscriptionLimitsFallback(subscription, group)
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnFailure(err)
+ }
+ log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
+ return ErrBillingServiceUnavailable.WithCause(err)
+ }
+ if s.circuitBreaker != nil {
+ s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil
}
+type billingCircuitBreakerState int
+
+const (
+ billingCircuitClosed billingCircuitBreakerState = iota
+ billingCircuitOpen
+ billingCircuitHalfOpen
+)
+
+type billingCircuitBreaker struct {
+ mu sync.Mutex
+ state billingCircuitBreakerState
+ failures int
+ openedAt time.Time
+ failureThreshold int
+ resetTimeout time.Duration
+ halfOpenRequests int
+ halfOpenRemaining int
+}
+
+func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
+ if !cfg.Enabled {
+ return nil
+ }
+ resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
+ if resetTimeout <= 0 {
+ resetTimeout = 30 * time.Second
+ }
+ halfOpen := cfg.HalfOpenRequests
+ if halfOpen <= 0 {
+ halfOpen = 1
+ }
+ threshold := cfg.FailureThreshold
+ if threshold <= 0 {
+ threshold = 5
+ }
+ return &billingCircuitBreaker{
+ state: billingCircuitClosed,
+ failureThreshold: threshold,
+ resetTimeout: resetTimeout,
+ halfOpenRequests: halfOpen,
+ }
+}
+
+func (b *billingCircuitBreaker) Allow() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ switch b.state {
+ case billingCircuitClosed:
+ return true
+ case billingCircuitOpen:
+ if time.Since(b.openedAt) < b.resetTimeout {
+ return false
+ }
+ b.state = billingCircuitHalfOpen
+ b.halfOpenRemaining = b.halfOpenRequests
+ log.Printf("ALERT: billing circuit breaker entering half-open state")
+ fallthrough
+ case billingCircuitHalfOpen:
+ if b.halfOpenRemaining <= 0 {
+ return false
+ }
+ b.halfOpenRemaining--
+ return true
+ default:
+ return false
+ }
+}
+
+func (b *billingCircuitBreaker) OnFailure(err error) {
+ if b == nil {
+ return
+ }
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ switch b.state {
+ case billingCircuitOpen:
+ return
+ case billingCircuitHalfOpen:
+ b.state = billingCircuitOpen
+ b.openedAt = time.Now()
+ b.halfOpenRemaining = 0
+ log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
+ return
+ default:
+ b.failures++
+ if b.failures >= b.failureThreshold {
+ b.state = billingCircuitOpen
+ b.openedAt = time.Now()
+ b.halfOpenRemaining = 0
+ log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
+ }
+ }
+}
+
+func (b *billingCircuitBreaker) OnSuccess() {
+ if b == nil {
+ return
+ }
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ previousState := b.state
+ previousFailures := b.failures
+
+ b.state = billingCircuitClosed
+ b.failures = 0
+ b.halfOpenRemaining = 0
+
+ // 只有状态真正发生变化时才记录日志
+ if previousState != billingCircuitClosed {
+ log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
+ } else if previousFailures > 0 {
+ log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
+ }
+}
+
+func circuitStateString(state billingCircuitBreakerState) string {
+ switch state {
+ case billingCircuitClosed:
+ return "closed"
+ case billingCircuitOpen:
+ return "open"
+ case billingCircuitHalfOpen:
+ return "half-open"
+ default:
+ return "unknown"
+ }
+}
+
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go
index fd23ecb2..34efcae4 100644
--- a/backend/internal/service/crs_sync_service.go
+++ b/backend/internal/service/crs_sync_service.go
@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
- "net/url"
"strconv"
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type CRSSyncService struct {
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
+ cfg *config.Config
}
func NewCRSSyncService(
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
+ cfg *config.Config,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
+ cfg: cfg,
}
}
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
- baseURL, err := normalizeBaseURL(input.BaseURL)
+ if s.cfg == nil {
+ return nil, errors.New("config is not available")
+ }
+ baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active"
}
-func normalizeBaseURL(raw string) (string, error) {
- trimmed := strings.TrimSpace(raw)
- if trimmed == "" {
- return "", errors.New("base_url is required")
+func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
+ // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
+ requireAllowlist := len(allowlist) > 0
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: allowlist,
+ RequireAllowlist: requireAllowlist,
+ AllowPrivate: allowPrivate,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
}
- u, err := url.Parse(trimmed)
- if err != nil || u.Scheme == "" || u.Host == "" {
- return "", fmt.Errorf("invalid base_url: %s", trimmed)
- }
- u.Path = strings.TrimRight(u.Path, "/")
- return strings.TrimRight(u.String(), "/"), nil
+ return normalized, nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index d542e9c2..878ee722 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
- targetURL = baseURL + "/v1/messages"
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/v1/messages"
+ }
}
// OAuth账号:应用统一指纹
@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- // 透传响应头
- for key, values := range resp.Header {
- for _, value := range values {
- c.Header(key, value)
- }
- }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
- targetURL = baseURL + "/v1/messages/count_tokens"
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/v1/messages/count_tokens"
+ }
}
// OAuth 账号:应用统一指纹和重写 userID
@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
},
})
}
+
+func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index ee3ade16..35f27f8d 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -18,9 +18,12 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
+ cfg *config.Config
}
func NewGeminiMessagesCompatService(
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
+ cfg *config.Config,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
+ cfg: cfg,
}
}
@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
+func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
+
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
action := "generateContent"
if req.Stream {
action = "streamGenerateContent"
}
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream {
fullURL += "?alt=sse"
}
@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" {
// Mode 1: Code Assist API
- fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
+ baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, "", err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API
- fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
+ baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, "", err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
- fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream {
fullURL += "?alt=sse"
}
@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed)
}
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
- baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
+ baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
- fullURL := strings.TrimRight(baseURL, "/") + path
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
+ wwwAuthenticate := resp.Header.Get("Www-Authenticate")
+ filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
+ if wwwAuthenticate != "" {
+ filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
+ }
return &UpstreamHTTPResult{
StatusCode: resp.StatusCode,
- Headers: resp.Header.Clone(),
+ Headers: filteredHeaders,
Body: body,
}, nil
}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 84e98679..c3d3cab5 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -18,6 +18,8 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
- if baseURL != "" {
- targetURL = baseURL + "/responses"
- } else {
+ if baseURL == "" {
targetURL = openaiPlatformAPIURL
+ } else {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = validatedURL + "/responses"
}
default:
targetURL = openaiPlatformAPIURL
@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- // Pass through headers
- for key, values := range resp.Header {
- for _, value := range values {
- c.Header(key, value)
- }
- }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
c.Data(resp.StatusCode, "application/json", body)
return usage, nil
}
+func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+}
+
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index bb050d0a..58a24c0d 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var (
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
- log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
+ remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
+ if err != nil {
+ return err
+ }
+ log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
- body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
+ var expectedHash string
+ if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
+ expectedHash, err = s.fetchRemoteHash()
+ if err != nil {
+ return fmt.Errorf("fetch remote hash: %w", err)
+ }
+ }
+
+ body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
+ if expectedHash != "" {
+ actualHash := sha256.Sum256(body)
+ if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
+ return fmt.Errorf("pricing hash mismatch")
+ }
+ }
+
// 解析JSON数据(使用灵活的解析方式)
data, err := s.parsePricingData(body)
if err != nil {
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
+ hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
+ if err != nil {
+ return "", err
+ }
+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
+ hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
+ if err != nil {
+ return "", err
+ }
+ return strings.TrimSpace(hash), nil
+}
+
+func (s *PricingService) validatePricingURL(raw string) (string, error) {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
+ AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
+ RequireAllowlist: true,
+ AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
+ })
+ if err != nil {
+ return "", fmt.Errorf("invalid pricing url: %w", err)
+ }
+ return normalized, nil
}
// computeFileHash 计算文件哈希
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 0ffe991d..fc8859ca 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
+ SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
+ TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
- // 敏感信息直接返回,方便测试连接时使用
- result.SmtpPassword = settings[SettingKeySmtpPassword]
- result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
-
return result
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index cb9751d1..11c64f13 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -8,6 +8,7 @@ type SystemSettings struct {
SmtpPort int
SmtpUsername string
SmtpPassword string
+ SmtpPasswordConfigured bool
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
@@ -15,6 +16,7 @@ type SystemSettings struct {
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
+ TurnstileSecretKeyConfigured bool
SiteName string
SiteLogo string
diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go
index 5565ab91..344f2db7 100644
--- a/backend/internal/setup/setup.go
+++ b/backend/internal/setup/setup.go
@@ -9,6 +9,7 @@ import (
"log"
"os"
"strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
+ if strings.EqualFold(cfg.Server.Mode, "release") {
+ return fmt.Errorf("jwt secret is required in release mode")
+ }
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Test connections
@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error {
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
+ if strings.EqualFold(cfg.Server.Mode, "release") {
+ return fmt.Errorf("jwt secret is required in release mode")
+ }
secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
- log.Println("Generated JWT secret automatically")
+ log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
+ } else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
+ return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
}
// Generate admin password if not provided
@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err)
}
cfg.Admin.Password = password
- log.Printf("Generated admin password: %s", cfg.Admin.Password)
- log.Println("IMPORTANT: Save this password! It will not be shown again.")
+ fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
+ fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection
diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go
new file mode 100644
index 00000000..b2d2429f
--- /dev/null
+++ b/backend/internal/util/logredact/redact.go
@@ -0,0 +1,100 @@
+package logredact
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+// maxRedactDepth 限制递归深度以防止栈溢出
+const maxRedactDepth = 32
+
+var defaultSensitiveKeys = map[string]struct{}{
+ "authorization_code": {},
+ "code": {},
+ "code_verifier": {},
+ "access_token": {},
+ "refresh_token": {},
+ "id_token": {},
+ "client_secret": {},
+ "password": {},
+}
+
+func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
+ if input == nil {
+ return map[string]any{}
+ }
+ keys := buildKeySet(extraKeys)
+ redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
+ if !ok {
+ return map[string]any{}
+ }
+ return redacted
+}
+
+func RedactJSON(raw []byte, extraKeys ...string) string {
+ if len(raw) == 0 {
+ return ""
+ }
+ var value any
+ if err := json.Unmarshal(raw, &value); err != nil {
+ return ""
+ }
+ keys := buildKeySet(extraKeys)
+ redacted := redactValueWithDepth(value, keys, 0)
+ encoded, err := json.Marshal(redacted)
+ if err != nil {
+ return ""
+ }
+ return string(encoded)
+}
+
+func buildKeySet(extraKeys []string) map[string]struct{} {
+ keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
+ for k := range defaultSensitiveKeys {
+ keys[k] = struct{}{}
+ }
+ for _, key := range extraKeys {
+ normalized := normalizeKey(key)
+ if normalized == "" {
+ continue
+ }
+ keys[normalized] = struct{}{}
+ }
+ return keys
+}
+
+func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
+ if depth > maxRedactDepth {
+ return ""
+ }
+
+ switch v := value.(type) {
+ case map[string]any:
+ out := make(map[string]any, len(v))
+ for k, val := range v {
+ if isSensitiveKey(k, keys) {
+ out[k] = "***"
+ continue
+ }
+ out[k] = redactValueWithDepth(val, keys, depth+1)
+ }
+ return out
+ case []any:
+ out := make([]any, len(v))
+ for i, item := range v {
+ out[i] = redactValueWithDepth(item, keys, depth+1)
+ }
+ return out
+ default:
+ return value
+ }
+}
+
+func isSensitiveKey(key string, keys map[string]struct{}) bool {
+ _, ok := keys[normalizeKey(key)]
+ return ok
+}
+
+func normalizeKey(key string) string {
+ return strings.ToLower(strings.TrimSpace(key))
+}
diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go
new file mode 100644
index 00000000..3635f1b4
--- /dev/null
+++ b/backend/internal/util/responseheaders/responseheaders.go
@@ -0,0 +1,92 @@
+package responseheaders
+
+import (
+ "net/http"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+// defaultAllowed 定义允许透传的响应头白名单
+// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
+// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
+// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
+// - connection: 由 HTTP 库管理连接复用
+var defaultAllowed = map[string]struct{}{
+ "content-type": {},
+ "content-encoding": {},
+ "content-language": {},
+ "cache-control": {},
+ "etag": {},
+ "last-modified": {},
+ "expires": {},
+ "vary": {},
+ "date": {},
+ "x-request-id": {},
+ "x-ratelimit-limit-requests": {},
+ "x-ratelimit-limit-tokens": {},
+ "x-ratelimit-remaining-requests": {},
+ "x-ratelimit-remaining-tokens": {},
+ "x-ratelimit-reset-requests": {},
+ "x-ratelimit-reset-tokens": {},
+ "retry-after": {},
+ "location": {},
+}
+
+// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
+var hopByHopHeaders = map[string]struct{}{
+ "content-length": {},
+ "transfer-encoding": {},
+ "connection": {},
+}
+
+func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
+ allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
+ for key := range defaultAllowed {
+ allowed[key] = struct{}{}
+ }
+ for _, key := range cfg.AdditionalAllowed {
+ normalized := strings.ToLower(strings.TrimSpace(key))
+ if normalized == "" {
+ continue
+ }
+ allowed[normalized] = struct{}{}
+ }
+
+ forceRemove := make(map[string]struct{}, len(cfg.ForceRemove))
+ for _, key := range cfg.ForceRemove {
+ normalized := strings.ToLower(strings.TrimSpace(key))
+ if normalized == "" {
+ continue
+ }
+ forceRemove[normalized] = struct{}{}
+ }
+
+ filtered := make(http.Header, len(src))
+ for key, values := range src {
+ lower := strings.ToLower(key)
+ if _, blocked := forceRemove[lower]; blocked {
+ continue
+ }
+ if _, ok := allowed[lower]; !ok {
+ continue
+ }
+ // 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
+ if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
+ continue
+ }
+ for _, value := range values {
+ filtered.Add(key, value)
+ }
+ }
+ return filtered
+}
+
+func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
+ filtered := FilterHeaders(src, cfg)
+ for key, values := range filtered {
+ for _, value := range values {
+ dst.Add(key, value)
+ }
+ }
+}
diff --git a/backend/internal/util/urlvalidator/validator.go b/backend/internal/util/urlvalidator/validator.go
new file mode 100644
index 00000000..b8f8c72f
--- /dev/null
+++ b/backend/internal/util/urlvalidator/validator.go
@@ -0,0 +1,121 @@
+package urlvalidator
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "strings"
+ "time"
+)
+
+type ValidationOptions struct {
+ AllowedHosts []string
+ RequireAllowlist bool
+ AllowPrivate bool
+}
+
+func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return "", errors.New("url is required")
+ }
+
+ parsed, err := url.Parse(trimmed)
+ if err != nil || parsed.Scheme == "" || parsed.Host == "" {
+ return "", fmt.Errorf("invalid url: %s", trimmed)
+ }
+ if !strings.EqualFold(parsed.Scheme, "https") {
+ return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
+ }
+
+ host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ if host == "" {
+ return "", errors.New("invalid host")
+ }
+ if !opts.AllowPrivate && isBlockedHost(host) {
+ return "", fmt.Errorf("host is not allowed: %s", host)
+ }
+
+ allowlist := normalizeAllowlist(opts.AllowedHosts)
+ if opts.RequireAllowlist && len(allowlist) == 0 {
+ return "", errors.New("allowlist is not configured")
+ }
+ if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
+ return "", fmt.Errorf("host is not allowed: %s", host)
+ }
+
+ parsed.Path = strings.TrimRight(parsed.Path, "/")
+ parsed.RawPath = ""
+ return strings.TrimRight(parsed.String(), "/"), nil
+}
+
+// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
+// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
+func ValidateResolvedIP(host string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return fmt.Errorf("dns resolution failed: %w", err)
+ }
+
+ for _, ip := range ips {
+ if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
+ ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
+ return fmt.Errorf("resolved ip %s is not allowed", ip.String())
+ }
+ }
+ return nil
+}
+
+func normalizeAllowlist(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ normalized := make([]string, 0, len(values))
+ for _, v := range values {
+ entry := strings.ToLower(strings.TrimSpace(v))
+ if entry == "" {
+ continue
+ }
+ if host, _, err := net.SplitHostPort(entry); err == nil {
+ entry = host
+ }
+ normalized = append(normalized, entry)
+ }
+ return normalized
+}
+
+func isAllowedHost(host string, allowlist []string) bool {
+ for _, entry := range allowlist {
+ if entry == "" {
+ continue
+ }
+ if strings.HasPrefix(entry, "*.") {
+ suffix := strings.TrimPrefix(entry, "*.")
+ if host == suffix || strings.HasSuffix(host, "."+suffix) {
+ return true
+ }
+ continue
+ }
+ if host == entry {
+ return true
+ }
+ }
+ return false
+}
+
+func isBlockedHost(host string) bool {
+ if host == "localhost" || strings.HasSuffix(host, ".localhost") {
+ return true
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
+ return true
+ }
+ }
+ return false
+}
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 5bd85d7d..b62806c4 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -12,6 +12,8 @@ server:
port: 8080
# Mode: "debug" for development, "release" for production
mode: "release"
+ # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
+ trusted_proxies: []
# =============================================================================
# Run Mode Configuration
@@ -21,6 +23,48 @@ server:
# - simple: Hides SaaS features and skips billing/balance checks
run_mode: "standard"
+# =============================================================================
+# CORS Configuration
+# =============================================================================
+cors:
+ # Allowed origins list. Leave empty to disable cross-origin requests.
+ allowed_origins: []
+ # Allow credentials (cookies/authorization headers). Cannot be used with "*".
+ allow_credentials: true
+
+# =============================================================================
+# Security Configuration
+# =============================================================================
+security:
+ url_allowlist:
+ # Allowed upstream hosts for API proxying
+ upstream_hosts:
+ - "api.openai.com"
+ - "api.anthropic.com"
+ - "generativelanguage.googleapis.com"
+ - "cloudcode-pa.googleapis.com"
+ - "*.openai.azure.com"
+ # Allowed hosts for pricing data download
+ pricing_hosts:
+ - "raw.githubusercontent.com"
+ # Allowed hosts for CRS sync (required when using CRS sync)
+ crs_hosts: []
+ # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks)
+ allow_private_hosts: false
+ response_headers:
+ # Extra allowed response headers from upstream
+ additional_allowed: []
+ # Force-remove response headers from upstream
+ force_remove: []
+ csp:
+ # Enable Content-Security-Policy header
+ enabled: true
+ # Default CSP policy (override if you host assets on other domains)
+ policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+ proxy_probe:
+ # Allow skipping TLS verification for proxy probe (debug only)
+ insecure_skip_verify: false
+
# =============================================================================
# 网关配置
# =============================================================================
@@ -77,7 +121,7 @@ jwt:
# IMPORTANT: Change this to a random string in production!
# Generate with: openssl rand -hex 32
secret: "change-this-to-a-secure-random-string"
- # Token expiration time in hours
+ # Token expiration time in hours (max 24)
expire_hour: 24
# =============================================================================
@@ -122,6 +166,23 @@ pricing:
# Hash check interval in minutes
hash_check_interval_minutes: 10
+# =============================================================================
+# Billing Configuration
+# =============================================================================
+billing:
+ circuit_breaker:
+ enabled: true
+ failure_threshold: 5
+ reset_timeout_seconds: 30
+ half_open_requests: 3
+
+# =============================================================================
+# Turnstile Configuration
+# =============================================================================
+turnstile:
+ # Require Turnstile in release mode (when enabled, login/register will fail if not configured)
+ required: false
+
# =============================================================================
# Gemini OAuth (Required for Gemini accounts)
# =============================================================================
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index cf5cba6d..6c89f674 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -26,14 +26,37 @@ export interface SystemSettings {
smtp_host: string
smtp_port: number
smtp_username: string
- smtp_password: string
+ smtp_password_configured: boolean
smtp_from_email: string
smtp_from_name: string
smtp_use_tls: boolean
// Cloudflare Turnstile settings
turnstile_enabled: boolean
turnstile_site_key: string
- turnstile_secret_key: string
+ turnstile_secret_key_configured: boolean
+}
+
+export interface UpdateSettingsRequest {
+ registration_enabled?: boolean
+ email_verify_enabled?: boolean
+ default_balance?: number
+ default_concurrency?: number
+ site_name?: string
+ site_logo?: string
+ site_subtitle?: string
+ api_base_url?: string
+ contact_info?: string
+ doc_url?: string
+ smtp_host?: string
+ smtp_port?: number
+ smtp_username?: string
+ smtp_password?: string
+ smtp_from_email?: string
+ smtp_from_name?: string
+ smtp_use_tls?: boolean
+ turnstile_enabled?: boolean
+ turnstile_site_key?: string
+ turnstile_secret_key?: string
}
/**
@@ -50,7 +73,7 @@ export async function getSettings(): Promise {
* @param settings - Partial settings to update
* @returns Updated settings
*/
-export async function updateSettings(settings: Partial): Promise {
+export async function updateSettings(settings: UpdateSettingsRequest): Promise {
const { data } = await apiClient.put('/admin/settings', settings)
return data
}
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 3aac41a6..e8c0a44c 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -62,8 +62,24 @@ apiClient.interceptors.response.use(
// 401: Unauthorized - clear token and redirect to login
if (status === 401) {
+ const hasToken = !!localStorage.getItem('auth_token')
+ const url = error.config?.url || ''
+ const isAuthEndpoint =
+ url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
+ const headers = error.config?.headers as Record | undefined
+ const authHeader = headers?.Authorization ?? headers?.authorization
+ const sentAuth =
+ typeof authHeader === 'string'
+ ? authHeader.trim() !== ''
+ : Array.isArray(authHeader)
+ ? authHeader.length > 0
+ : !!authHeader
+
localStorage.removeItem('auth_token')
localStorage.removeItem('auth_user')
+ if ((hasToken || sentAuth) && !isAuthEndpoint) {
+ sessionStorage.setItem('auth_expired', '1')
+ }
// Only redirect if not already on login page
if (!window.location.pathname.includes('/login')) {
window.location.href = '/login'
diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue
index 41c316f5..7ce30b46 100644
--- a/frontend/src/components/account/OAuthAuthorizationFlow.vue
+++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue
@@ -136,16 +136,16 @@
-
-
-
-
-
-
+ - {{ t('admin.accounts.oauth.step1') }}
+ - {{ t('admin.accounts.oauth.step2') }}
+ - {{ t('admin.accounts.oauth.step3') }}
+ - {{ t('admin.accounts.oauth.step4') }}
+ - {{ t('admin.accounts.oauth.step5') }}
+ - {{ t('admin.accounts.oauth.step6') }}
@@ -390,7 +390,7 @@
>
@@ -400,7 +400,7 @@
>
@@ -423,7 +423,7 @@
-
+
@@ -142,7 +142,6 @@ interface TabConfig {
interface FileConfig {
path: string
content: string
- highlighted: string
hint?: string // Optional hint message for this file
}
@@ -227,13 +226,6 @@ const platformNote = computed(() => {
})
// Syntax highlighting helpers
-const keyword = (text: string) => `${text}`
-const variable = (text: string) => `${text}`
-const string = (text: string) => `${text}`
-const operator = (text: string) => `${text}`
-const comment = (text: string) => `${text}`
-const key = (text: string) => `${text}`
-
// Generate file configs based on platform and active tab
const currentFiles = computed((): FileConfig[] => {
const baseUrl = props.baseUrl || window.location.origin
@@ -249,37 +241,29 @@ const currentFiles = computed((): FileConfig[] => {
function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] {
let path: string
let content: string
- let highlighted: string
switch (activeTab.value) {
case 'unix':
path = 'Terminal'
content = `export ANTHROPIC_BASE_URL="${baseUrl}"
export ANTHROPIC_AUTH_TOKEN="${apiKey}"`
- highlighted = `${keyword('export')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
-${keyword('export')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break
case 'cmd':
path = 'Command Prompt'
content = `set ANTHROPIC_BASE_URL=${baseUrl}
set ANTHROPIC_AUTH_TOKEN=${apiKey}`
- highlighted = `${keyword('set')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${baseUrl}
-${keyword('set')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${apiKey}`
break
case 'powershell':
path = 'PowerShell'
content = `$env:ANTHROPIC_BASE_URL="${baseUrl}"
$env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
- highlighted = `${keyword('$env:')}${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
-${keyword('$env:')}${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break
default:
path = 'Terminal'
content = ''
- highlighted = ''
}
- return [{ path, content, highlighted }]
+ return [{ path, content }]
}
function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] {
@@ -301,40 +285,20 @@ base_url = "${baseUrl}"
wire_api = "responses"
requires_openai_auth = true`
- const configHighlighted = `${key('model_provider')} ${operator('=')} ${string('"sub2api"')}
-${key('model')} ${operator('=')} ${string('"gpt-5.2-codex"')}
-${key('model_reasoning_effort')} ${operator('=')} ${string('"high"')}
-${key('network_access')} ${operator('=')} ${string('"enabled"')}
-${key('disable_response_storage')} ${operator('=')} ${keyword('true')}
-${key('windows_wsl_setup_acknowledged')} ${operator('=')} ${keyword('true')}
-${key('model_verbosity')} ${operator('=')} ${string('"high"')}
-
-${comment('[model_providers.sub2api]')}
-${key('name')} ${operator('=')} ${string('"sub2api"')}
-${key('base_url')} ${operator('=')} ${string(`"${baseUrl}"`)}
-${key('wire_api')} ${operator('=')} ${string('"responses"')}
-${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
-
// auth.json content
const authContent = `{
"OPENAI_API_KEY": "${apiKey}"
}`
- const authHighlighted = `{
- ${key('"OPENAI_API_KEY"')}: ${string(`"${apiKey}"`)}
-}`
-
return [
{
path: `${configDir}/config.toml`,
content: configContent,
- highlighted: configHighlighted,
hint: t('keys.useKeyModal.openai.configTomlHint')
},
{
path: `${configDir}/auth.json`,
- content: authContent,
- highlighted: authHighlighted
+ content: authContent
}
]
}
diff --git a/frontend/src/components/layout/AuthLayout.vue b/frontend/src/components/layout/AuthLayout.vue
index 1a0cfec7..3cfc1d4d 100644
--- a/frontend/src/components/layout/AuthLayout.vue
+++ b/frontend/src/components/layout/AuthLayout.vue
@@ -63,6 +63,7 @@
diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue
index bc8986f6..1a3b584b 100644
--- a/frontend/src/views/user/KeysView.vue
+++ b/frontend/src/views/user/KeysView.vue
@@ -335,12 +335,14 @@
/>
{{ t('keys.selectGroup') }}
-
-
+
@@ -516,26 +518,19 @@
? 'bg-primary-50 dark:bg-primary-900/20'
: 'hover:bg-gray-100 dark:hover:bg-dark-700'
]"
+ :title="option.description || undefined"
>
-
-
+ />
@@ -562,6 +557,7 @@ import EmptyState from '@/components/common/EmptyState.vue'
import Select from '@/components/common/Select.vue'
import UseKeyModal from '@/components/keys/UseKeyModal.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
+import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'
import type { Column } from '@/components/common/types'
import type { BatchApiKeyUsageStats } from '@/api/usage'
@@ -570,6 +566,7 @@ import { formatDateTime } from '@/utils/format'
interface GroupOption {
value: number
label: string
+ description: string | null
rate: number
subscriptionType: SubscriptionType
platform: GroupPlatform
@@ -665,6 +662,7 @@ const groupOptions = computed(() =>
groups.value.map((group) => ({
value: group.id,
label: group.name,
+ description: group.description,
rate: group.rate_multiplier,
subscriptionType: group.subscription_type,
platform: group.platform
diff --git a/frontend/vite.config.js b/frontend/vite.config.js
deleted file mode 100644
index efcf347a..00000000
--- a/frontend/vite.config.js
+++ /dev/null
@@ -1,36 +0,0 @@
-import { defineConfig } from 'vite';
-import vue from '@vitejs/plugin-vue';
-import checker from 'vite-plugin-checker';
-import { resolve } from 'path';
-export default defineConfig({
- plugins: [
- vue(),
- checker({
- typescript: true,
- vueTsc: true
- })
- ],
- resolve: {
- alias: {
- '@': resolve(__dirname, 'src')
- }
- },
- build: {
- outDir: '../backend/internal/web/dist',
- emptyOutDir: true
- },
- server: {
- host: '0.0.0.0',
- port: 3000,
- proxy: {
- '/api': {
- target: 'http://localhost:8080',
- changeOrigin: true
- },
- '/setup': {
- target: 'http://localhost:8080',
- changeOrigin: true
- }
- }
- }
-});
From 5dd8b8802bd27d3234e5c965abe7f20e998c0411 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 4 Jan 2026 22:10:32 +0800
Subject: [PATCH 6/7] =?UTF-8?q?fix(=E5=90=8E=E7=AB=AF):=20=E4=BF=AE?=
=?UTF-8?q?=E5=A4=8D=20lint=20=E5=A4=B1=E8=B4=A5=E5=B9=B6=E6=B8=85?=
=?UTF-8?q?=E7=90=86=E6=97=A0=E7=94=A8=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
修正测试中的 APIKey 名称引用
移除不可达返回与未使用函数
统一 gofmt 格式并处理 Close 错误
---
backend/internal/config/config.go | 6 ++--
backend/internal/handler/gateway_handler.go | 2 +-
.../repository/claude_oauth_service.go | 10 ------
.../middleware/api_key_auth_google_test.go | 14 ++++----
.../internal/service/billing_cache_service.go | 35 +++----------------
backend/internal/service/gateway_service.go | 1 -
.../service/openai_gateway_service.go | 1 -
.../service/openai_gateway_service_test.go | 2 +-
8 files changed, 17 insertions(+), 54 deletions(-)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index a1d80ad6..1d8c64ef 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -497,9 +497,9 @@ func setDefaults() {
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
- viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
- viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
- viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
+ viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
+ viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
+ viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 9528d9c0..de3cbad9 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -13,8 +13,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
- pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index 8595e783..677fce52 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -246,13 +246,3 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
-
-func prefix(s string, n int) string {
- if n <= 0 {
- return ""
- }
- if len(s) <= n {
- return s
- }
- return s[:n]
-}
diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go
index 9397406e..0ed5a4a2 100644
--- a/backend/internal/server/middleware/api_key_auth_google_test.go
+++ b/backend/internal/server/middleware/api_key_auth_google_test.go
@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
- apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
- getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
+ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
- r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
+ r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
@@ -137,9 +137,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T)
gin.SetMode(gin.TestMode)
r := gin.New()
- apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
- getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
- return &service.ApiKey{
+ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
+ return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
@@ -151,7 +151,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T)
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
- r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
+ r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index 8112090f..c09cafb9 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -16,7 +16,7 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
- ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
+ ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
@@ -73,10 +73,10 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
- cache BillingCache
- userRepo UserRepository
- subRepo UserSubscriptionRepository
- cfg *config.Config
+ cache BillingCache
+ userRepo UserRepository
+ subRepo UserSubscriptionRepository
+ cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
@@ -659,28 +659,3 @@ func circuitStateString(state billingCircuitBreakerState) string {
return "unknown"
}
}
-
-// checkSubscriptionLimitsFallback 降级检查订阅限额
-func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
- if subscription == nil {
- return ErrSubscriptionInvalid
- }
-
- if !subscription.IsActive() {
- return ErrSubscriptionInvalid
- }
-
- if !subscription.CheckDailyLimit(group, 0) {
- return ErrDailyLimitExceeded
- }
-
- if !subscription.CheckWeeklyLimit(group, 0) {
- return ErrWeeklyLimitExceeded
- }
-
- if !subscription.CheckMonthlyLimit(group, 0) {
- return ErrMonthlyLimitExceeded
- }
-
- return nil
-}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index cce76918..75a157c8 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -1731,7 +1731,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
- return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 9a4a470c..b9cf4b9e 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -934,7 +934,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
}
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index dd8ca6b6..bcad7ac8 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -72,7 +72,7 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
}
go func() {
- defer pw.Close()
+ defer func() { _ = pw.Close() }()
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
_, _ = pw.Write([]byte(payload))
From f8e7255c32196c67b28cfbf91a58ff5a3265329d Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 4 Jan 2026 22:19:11 +0800
Subject: [PATCH 7/7] =?UTF-8?q?feat(=E7=95=8C=E9=9D=A2):=20=E4=B8=BA=20Gem?=
=?UTF-8?q?ini=20=E9=85=8D=E7=BD=AE=E7=89=87=E6=AE=B5=E6=B7=BB=E5=8A=A0?=
=?UTF-8?q?=E8=AF=AD=E6=B3=95=E9=AB=98=E4=BA=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
补齐高亮渲染并保留纯文本回退
新增高亮 token 工具并做 HTML 转义
---
frontend/src/components/keys/UseKeyModal.vue | 28 +++++++++++++++++---
1 file changed, 24 insertions(+), 4 deletions(-)
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index dabb60af..3d687b5a 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -107,7 +107,10 @@
-
+
+
+
+
@@ -165,6 +168,7 @@ interface FileConfig {
path: string
content: string
hint?: string // Optional hint message for this file
+ highlighted?: string
}
const props = defineProps()
@@ -310,6 +314,22 @@ const platformNote = computed(() => {
}
})
+const escapeHtml = (value: string) => value
+ .replace(/&/g, '&')
+ .replace(//g, '>')
+ .replace(/"/g, '"')
+ .replace(/'/g, ''')
+
+const wrapToken = (className: string, value: string) =>
+ `${escapeHtml(value)}`
+
+const keyword = (value: string) => wrapToken('text-emerald-300', value)
+const variable = (value: string) => wrapToken('text-sky-200', value)
+const operator = (value: string) => wrapToken('text-slate-400', value)
+const string = (value: string) => wrapToken('text-amber-200', value)
+const comment = (value: string) => wrapToken('text-slate-500', value)
+
// Syntax highlighting helpers
// Generate file configs based on platform and active tab
const currentFiles = computed((): FileConfig[] => {
@@ -382,9 +402,9 @@ ${keyword('export')} ${variable('GEMINI_MODEL')}${operator('=')}${string(`"${mod
content = `set GOOGLE_GEMINI_BASE_URL=${baseUrl}
set GEMINI_API_KEY=${apiKey}
set GEMINI_MODEL=${model}`
- highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${baseUrl}
-${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${apiKey}
-${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${model}
+ highlighted = `${keyword('set')} ${variable('GOOGLE_GEMINI_BASE_URL')}${operator('=')}${string(baseUrl)}
+${keyword('set')} ${variable('GEMINI_API_KEY')}${operator('=')}${string(apiKey)}
+${keyword('set')} ${variable('GEMINI_MODEL')}${operator('=')}${string(model)}
${comment(`REM ${modelComment}`)}`
break
case 'powershell':