feat(安全): 强化安全策略与配置校验
- 增加 CORS/CSP/安全响应头与代理信任配置 - 引入 URL 白名单与私网开关,校验上游与价格源 - 改善 API Key 处理与网关错误返回 - 管理端设置隐藏敏感字段并优化前端提示 - 增加计费熔断与相关配置示例 测试: go test ./...
This commit is contained in:
@@ -268,6 +268,15 @@ default:
|
||||
rate_multiplier: 1.0
|
||||
```
|
||||
|
||||
Additional security-related options are available in `config.yaml`:
|
||||
|
||||
- `cors.allowed_origins` for CORS allowlist
|
||||
- `security.url_allowlist` for upstream/pricing/CRS host allowlists
|
||||
- `security.csp` to control Content-Security-Policy headers
|
||||
- `billing.circuit_breaker` to fail closed on billing errors
|
||||
- `server.trusted_proxies` to enable X-Forwarded-For parsing
|
||||
- `turnstile.required` to require Turnstile in release mode
|
||||
|
||||
```bash
|
||||
# 6. Run the application
|
||||
./sub2api
|
||||
|
||||
@@ -268,6 +268,15 @@ default:
|
||||
rate_multiplier: 1.0
|
||||
```
|
||||
|
||||
`config.yaml` 还支持以下安全相关配置:
|
||||
|
||||
- `cors.allowed_origins` 配置 CORS 白名单
|
||||
- `security.url_allowlist` 配置上游/价格数据/CRS 主机白名单
|
||||
- `security.csp` 配置 Content-Security-Policy
|
||||
- `billing.circuit_breaker` 计费异常时 fail-closed
|
||||
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
||||
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
||||
|
||||
```bash
|
||||
# 6. 运行应用
|
||||
./sub2api
|
||||
|
||||
@@ -86,7 +86,8 @@ func main() {
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
r.Use(middleware.CORS(config.CORSConfig{}))
|
||||
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
|
||||
|
||||
// Register setup routes
|
||||
setup.RegisterRoutes(r)
|
||||
|
||||
@@ -77,7 +77,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
@@ -98,10 +98,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
@@ -129,7 +129,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
timingWheelService := service.ProvideTimingWheelService()
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@@ -12,6 +15,8 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
const (
|
||||
@@ -28,6 +33,10 @@ const (
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
@@ -81,11 +90,56 @@ type PricingConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string `mapstructure:"allowed_origins"`
|
||||
AllowCredentials bool `mapstructure:"allow_credentials"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
|
||||
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
|
||||
CSP CSPConfig `mapstructure:"csp"`
|
||||
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
|
||||
}
|
||||
|
||||
type URLAllowlistConfig struct {
|
||||
UpstreamHosts []string `mapstructure:"upstream_hosts"`
|
||||
PricingHosts []string `mapstructure:"pricing_hosts"`
|
||||
CRSHosts []string `mapstructure:"crs_hosts"`
|
||||
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
|
||||
}
|
||||
|
||||
type ResponseHeaderConfig struct {
|
||||
AdditionalAllowed []string `mapstructure:"additional_allowed"`
|
||||
ForceRemove []string `mapstructure:"force_remove"`
|
||||
}
|
||||
|
||||
type CSPConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Policy string `mapstructure:"policy"`
|
||||
}
|
||||
|
||||
type ProxyProbeConfig struct {
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||
}
|
||||
|
||||
type BillingConfig struct {
|
||||
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
|
||||
}
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
FailureThreshold int `mapstructure:"failure_threshold"`
|
||||
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
|
||||
HalfOpenRequests int `mapstructure:"half_open_requests"`
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
@@ -192,6 +246,10 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
@@ -242,11 +300,39 @@ func Load() (*Config, error) {
|
||||
}
|
||||
|
||||
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
|
||||
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret == "" {
|
||||
secret, err := generateJWTSecret(64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate jwt secret error: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Mode != "release" && cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
}
|
||||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||||
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
cfg.Security.ResponseHeaders.ForceRemove,
|
||||
)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -259,6 +345,39 @@ func setDefaults() {
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.trusted_proxies", []string{})
|
||||
|
||||
// CORS
|
||||
viper.SetDefault("cors.allowed_origins", []string{})
|
||||
viper.SetDefault("cors.allow_credentials", true)
|
||||
|
||||
// Security
|
||||
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
|
||||
"api.openai.com",
|
||||
"api.anthropic.com",
|
||||
"generativelanguage.googleapis.com",
|
||||
"cloudcode-pa.googleapis.com",
|
||||
"*.openai.azure.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{
|
||||
"raw.githubusercontent.com",
|
||||
})
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", false)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
viper.SetDefault("security.csp.enabled", true)
|
||||
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
|
||||
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
|
||||
|
||||
// Billing
|
||||
viper.SetDefault("billing.circuit_breaker.enabled", true)
|
||||
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
|
||||
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
|
||||
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
|
||||
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@@ -284,7 +403,7 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "change-me-in-production")
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// Default
|
||||
@@ -340,11 +459,39 @@ func setDefaults() {
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required")
|
||||
if c.Server.Mode == "release" {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required in release mode")
|
||||
}
|
||||
if len(c.JWT.Secret) < 32 {
|
||||
return fmt.Errorf("jwt.secret must be at least 32 characters")
|
||||
}
|
||||
if isWeakJWTSecret(c.JWT.Secret) {
|
||||
return fmt.Errorf("jwt.secret is too weak")
|
||||
}
|
||||
}
|
||||
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
|
||||
return fmt.Errorf("jwt.secret must be changed in production")
|
||||
if c.JWT.ExpireHour <= 0 {
|
||||
return fmt.Errorf("jwt.expire_hour must be positive")
|
||||
}
|
||||
if c.JWT.ExpireHour > 168 {
|
||||
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
|
||||
}
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
}
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive")
|
||||
}
|
||||
}
|
||||
if c.Database.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("database.max_open_conns must be positive")
|
||||
@@ -414,6 +561,51 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeStringSlice(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return values
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isWeakJWTSecret(secret string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(secret))
|
||||
if lower == "" {
|
||||
return true
|
||||
}
|
||||
weak := map[string]struct{}{
|
||||
"change-me-in-production": {},
|
||||
"changeme": {},
|
||||
"secret": {},
|
||||
"password": {},
|
||||
"123456": {},
|
||||
"12345678": {},
|
||||
"admin": {},
|
||||
"jwt-secret": {},
|
||||
}
|
||||
_, exists := weak[lower]
|
||||
return exists
|
||||
}
|
||||
|
||||
func generateJWTSecret(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 32
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// GetServerAddress returns the server address (host:port) from config file or environment variable.
|
||||
// This is a lightweight function that can be used before full config validation,
|
||||
// such as during setup wizard startup.
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -37,13 +41,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpPasswordConfigured: settings.SmtpPasswordConfigured,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
@@ -97,6 +101,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if req.DefaultConcurrency < 1 {
|
||||
req.DefaultConcurrency = 1
|
||||
@@ -136,6 +146,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.auditSettingsUpdate(c, previousSettings, settings, req)
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
@@ -149,13 +161,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpPasswordConfigured: updatedSettings.SmtpPasswordConfigured,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
@@ -167,6 +179,91 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
|
||||
if before == nil || after == nil {
|
||||
return
|
||||
}
|
||||
|
||||
changed := diffSettings(before, after, req)
|
||||
if len(changed) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
subject, _ := middleware.GetAuthSubjectFromContext(c)
|
||||
role, _ := middleware.GetUserRoleFromContext(c)
|
||||
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
|
||||
time.Now().UTC().Format(time.RFC3339),
|
||||
subject.UserID,
|
||||
role,
|
||||
changed,
|
||||
)
|
||||
}
|
||||
|
||||
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
|
||||
changed := make([]string, 0, 16)
|
||||
if before.RegistrationEnabled != after.RegistrationEnabled {
|
||||
changed = append(changed, "registration_enabled")
|
||||
}
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.SmtpHost != after.SmtpHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
if before.SmtpPort != after.SmtpPort {
|
||||
changed = append(changed, "smtp_port")
|
||||
}
|
||||
if before.SmtpUsername != after.SmtpUsername {
|
||||
changed = append(changed, "smtp_username")
|
||||
}
|
||||
if req.SmtpPassword != "" {
|
||||
changed = append(changed, "smtp_password")
|
||||
}
|
||||
if before.SmtpFrom != after.SmtpFrom {
|
||||
changed = append(changed, "smtp_from_email")
|
||||
}
|
||||
if before.SmtpFromName != after.SmtpFromName {
|
||||
changed = append(changed, "smtp_from_name")
|
||||
}
|
||||
if before.SmtpUseTLS != after.SmtpUseTLS {
|
||||
changed = append(changed, "smtp_use_tls")
|
||||
}
|
||||
if before.TurnstileEnabled != after.TurnstileEnabled {
|
||||
changed = append(changed, "turnstile_enabled")
|
||||
}
|
||||
if before.TurnstileSiteKey != after.TurnstileSiteKey {
|
||||
changed = append(changed, "turnstile_site_key")
|
||||
}
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.SiteName != after.SiteName {
|
||||
changed = append(changed, "site_name")
|
||||
}
|
||||
if before.SiteLogo != after.SiteLogo {
|
||||
changed = append(changed, "site_logo")
|
||||
}
|
||||
if before.SiteSubtitle != after.SiteSubtitle {
|
||||
changed = append(changed, "site_subtitle")
|
||||
}
|
||||
if before.ApiBaseUrl != after.ApiBaseUrl {
|
||||
changed = append(changed, "api_base_url")
|
||||
}
|
||||
if before.ContactInfo != after.ContactInfo {
|
||||
changed = append(changed, "contact_info")
|
||||
}
|
||||
if before.DocUrl != after.DocUrl {
|
||||
changed = append(changed, "doc_url")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
if before.DefaultBalance != after.DefaultBalance {
|
||||
changed = append(changed, "default_balance")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
|
||||
@@ -8,14 +8,14 @@ type SystemSettings struct {
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpPasswordConfigured bool `json:"smtp_password_configured"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
|
||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -127,7 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -535,7 +537,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.errorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -642,3 +645,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string) {
|
||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||
msg := infraerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = "Billing service temporarily unavailable. Please retry later."
|
||||
}
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg
|
||||
}
|
||||
msg := infraerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = err.Error()
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
@@ -190,7 +190,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
googleError(c, http.StatusForbidden, err.Error())
|
||||
status, _, message := billingErrorDetails(err)
|
||||
googleError(c, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -329,7 +330,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
|
||||
}
|
||||
for k, vv := range res.Headers {
|
||||
// Avoid overriding content-length and hop-by-hop headers.
|
||||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
|
||||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
|
||||
@@ -131,7 +131,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
@@ -54,7 +55,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -84,8 +85,8 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
@@ -113,7 +114,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -141,7 +142,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
@@ -173,8 +174,8 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
@@ -191,7 +192,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
@@ -8,25 +8,38 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"log"
|
||||
)
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
|
||||
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
insecureSkipVerify: insecure,
|
||||
}
|
||||
}
|
||||
|
||||
const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
ipInfoURL string
|
||||
insecureSkipVerify bool
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
InsecureSkipVerify: true,
|
||||
InsecureSkipVerify: s.insecureSkipVerify,
|
||||
ProxyStrict: true,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -295,13 +295,13 @@ func TestAPIContracts(t *testing.T) {
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
"smtp_password": "secret",
|
||||
"smtp_password_configured": true,
|
||||
"smtp_from_email": "no-reply@example.com",
|
||||
"smtp_from_name": "Sub2API",
|
||||
"smtp_use_tls": true,
|
||||
"turnstile_enabled": true,
|
||||
"turnstile_site_key": "site-key",
|
||||
"turnstile_secret_key": "secret-key",
|
||||
"turnstile_secret_key_configured": true,
|
||||
"site_name": "Sub2API",
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subtitle",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -35,6 +36,15 @@ func ProvideRouter(
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
if len(cfg.Server.TrustedProxies) > 0 {
|
||||
if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
|
||||
log.Printf("Failed to set trusted proxies: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.SetTrustedProxies(nil); err != nil {
|
||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,13 @@ func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionS
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
queryKey := strings.TrimSpace(c.Query("key"))
|
||||
queryApiKey := strings.TrimSpace(c.Query("api_key"))
|
||||
if queryKey != "" || queryApiKey != "" {
|
||||
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
var apiKeyString string
|
||||
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
apiKeyString = c.GetHeader("x-goog-api-key")
|
||||
}
|
||||
|
||||
// 如果header中没有,尝试从query参数中提取(Google API key风格)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("key")
|
||||
}
|
||||
|
||||
// 兼容常见别名
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("api_key")
|
||||
}
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,10 @@ func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config)
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
|
||||
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
|
||||
return
|
||||
}
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
abortWithGoogleError(c, 401, "API key is required")
|
||||
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
|
||||
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
|
||||
return v
|
||||
if allowGoogleQueryKey(c.Request.URL.Path) {
|
||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func allowGoogleQueryKey(path string) bool {
|
||||
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
|
||||
}
|
||||
|
||||
func abortWithGoogleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
|
||||
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
|
||||
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -1,24 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var corsWarningOnce sync.Once
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
|
||||
allowAll := false
|
||||
for _, origin := range allowedOrigins {
|
||||
if origin == "*" {
|
||||
allowAll = true
|
||||
break
|
||||
}
|
||||
}
|
||||
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
|
||||
if wildcardWithSpecific {
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
allowCredentials := cfg.AllowCredentials
|
||||
|
||||
corsWarningOnce.Do(func() {
|
||||
if len(allowedOrigins) == 0 {
|
||||
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
|
||||
}
|
||||
if wildcardWithSpecific {
|
||||
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
|
||||
}
|
||||
if allowAll && allowCredentials {
|
||||
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
|
||||
}
|
||||
})
|
||||
if allowAll && allowCredentials {
|
||||
allowCredentials = false
|
||||
}
|
||||
|
||||
allowedSet := make(map[string]struct{}, len(allowedOrigins))
|
||||
for _, origin := range allowedOrigins {
|
||||
if origin == "" || origin == "*" {
|
||||
continue
|
||||
}
|
||||
allowedSet[origin] = struct{}{}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 设置允许跨域的响应头
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
originAllowed := allowAll
|
||||
if origin != "" && !allowAll {
|
||||
_, originAllowed = allowedSet[origin]
|
||||
}
|
||||
|
||||
if originAllowed {
|
||||
if allowAll {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else if origin != "" {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Add("Vary", "Origin")
|
||||
}
|
||||
if allowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if originAllowed {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
} else {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOrigins(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
26
backend/internal/server/middleware/security_headers.go
Normal file
26
backend/internal/server/middleware/security_headers.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SecurityHeaders sets baseline security headers for all responses.
|
||||
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
policy := strings.TrimSpace(cfg.Policy)
|
||||
if policy == "" {
|
||||
policy = config.DefaultCSPPolicy
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
if cfg.Enabled {
|
||||
c.Header("Content-Security-Policy", policy)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -24,7 +24,8 @@ func SetupRouter(
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -15,9 +16,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -49,6 +52,7 @@ type AccountTestService struct {
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
@@ -59,6 +63,7 @@ func NewAccountTestService(
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -67,9 +72,25 @@ func NewAccountTestService(
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg == nil {
|
||||
return "", errors.New("config is not available")
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
apiURL = account.GetBaseURL()
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.anthropic.com"
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(baseURL, "/"), modelID)
|
||||
strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
|
||||
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
|
||||
|
||||
if required {
|
||||
if s.settingService == nil {
|
||||
log.Println("[Auth] Turnstile required but settings service is not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
enabled := s.settingService.IsTurnstileEnabled(ctx)
|
||||
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
|
||||
if !enabled || !secretConfigured {
|
||||
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
}
|
||||
|
||||
if s.turnstileService == nil {
|
||||
if required {
|
||||
log.Println("[Auth] Turnstile required but service not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
|
||||
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
|
||||
log.Println("[Auth] Turnstile enabled but secret key not configured")
|
||||
}
|
||||
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -76,6 +77,7 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||||
return ErrBillingServiceUnavailable
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,允许通过(降级处理)
|
||||
log.Printf("Warning: get user balance failed, allowing request: %v", err)
|
||||
return nil
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
// 缓存/数据库错误,降级使用传入的subscription进行检查
|
||||
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
|
||||
return s.checkSubscriptionLimitsFallback(subscription, group)
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
type billingCircuitBreakerState int
|
||||
|
||||
const (
|
||||
billingCircuitClosed billingCircuitBreakerState = iota
|
||||
billingCircuitOpen
|
||||
billingCircuitHalfOpen
|
||||
)
|
||||
|
||||
type billingCircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state billingCircuitBreakerState
|
||||
failures int
|
||||
openedAt time.Time
|
||||
failureThreshold int
|
||||
resetTimeout time.Duration
|
||||
halfOpenRequests int
|
||||
halfOpenRemaining int
|
||||
}
|
||||
|
||||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||||
if resetTimeout <= 0 {
|
||||
resetTimeout = 30 * time.Second
|
||||
}
|
||||
halfOpen := cfg.HalfOpenRequests
|
||||
if halfOpen <= 0 {
|
||||
halfOpen = 1
|
||||
}
|
||||
threshold := cfg.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 5
|
||||
}
|
||||
return &billingCircuitBreaker{
|
||||
state: billingCircuitClosed,
|
||||
failureThreshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
halfOpenRequests: halfOpen,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitClosed:
|
||||
return true
|
||||
case billingCircuitOpen:
|
||||
if time.Since(b.openedAt) < b.resetTimeout {
|
||||
return false
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
log.Printf("ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
return false
|
||||
}
|
||||
b.halfOpenRemaining--
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitOpen:
|
||||
return
|
||||
case billingCircuitHalfOpen:
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
if b.failures >= b.failureThreshold {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnSuccess() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
previousState := b.state
|
||||
previousFailures := b.failures
|
||||
|
||||
b.state = billingCircuitClosed
|
||||
b.failures = 0
|
||||
b.halfOpenRemaining = 0
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
func circuitStateString(state billingCircuitBreakerState) string {
|
||||
switch state {
|
||||
case billingCircuitClosed:
|
||||
return "closed"
|
||||
case billingCircuitOpen:
|
||||
return "open"
|
||||
case billingCircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// checkSubscriptionLimitsFallback 降级检查订阅限额
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
|
||||
if subscription == nil {
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -22,6 +23,7 @@ type CRSSyncService struct {
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
@@ -30,6 +32,7 @@ func NewCRSSyncService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
cfg *config.Config,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -37,6 +40,7 @@ func NewCRSSyncService(
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if s.cfg == nil {
|
||||
return nil, errors.New("config is not available")
|
||||
}
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string {
|
||||
return "active"
|
||||
}
|
||||
|
||||
func normalizeBaseURL(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("base_url is required")
|
||||
func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
|
||||
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
|
||||
requireAllowlist := len(allowlist) > 0
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: allowlist,
|
||||
RequireAllowlist: requireAllowlist,
|
||||
AllowPrivate: allowPrivate,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "", fmt.Errorf("invalid base_url: %s", trimmed)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/")
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// cleanBaseURL removes trailing suffix from base_url in credentials
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 透传响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
// 写入响应
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
@@ -18,9 +18,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
cfg *config.Config,
|
||||
) *GeminiMessagesCompatService {
|
||||
return &GeminiMessagesCompatService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if req.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if req.Stream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, "", errors.New("gemini api_key not configured")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
|
||||
if projectID != "" && !forceAIStudio {
|
||||
// Mode 1: Code Assist API
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction)
|
||||
baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return upstreamReq, "x-request-id", nil
|
||||
} else {
|
||||
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
|
||||
if useUpstreamStream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
return nil, errors.New("invalid path")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/")
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := strings.TrimRight(baseURL, "/") + path
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
if wwwAuthenticate != "" {
|
||||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||||
}
|
||||
return &UpstreamHTTPResult{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: resp.Header.Clone(),
|
||||
Headers: filteredHeaders,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/responses"
|
||||
} else {
|
||||
if baseURL == "" {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
} else {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/responses"
|
||||
}
|
||||
default:
|
||||
targetURL = openaiPlatformAPIURL
|
||||
@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
|
||||
|
||||
// downloadPricingData 从远程下载价格数据
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[Pricing] Downloading from %s", remoteURL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||
var expectedHash string
|
||||
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
|
||||
expectedHash, err = s.fetchRemoteHash()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch remote hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
|
||||
if expectedHash != "" {
|
||||
actualHash := sha256.Sum256(body)
|
||||
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
|
||||
return fmt.Errorf("pricing hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
if err != nil {
|
||||
@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||
hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(hash), nil
|
||||
}
|
||||
|
||||
func (s *PricingService) validatePricingURL(raw string) (string, error) {
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid pricing url: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
|
||||
@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ type SystemSettings struct {
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpPasswordConfigured bool
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
@@ -15,6 +16,7 @@ type SystemSettings struct {
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
if strings.EqualFold(cfg.Server.Mode, "release") {
|
||||
return fmt.Errorf("jwt secret is required in release mode")
|
||||
}
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
|
||||
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
|
||||
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
|
||||
}
|
||||
|
||||
// Test connections
|
||||
@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error {
|
||||
|
||||
// Generate JWT secret if not provided
|
||||
if cfg.JWT.Secret == "" {
|
||||
if strings.EqualFold(cfg.Server.Mode, "release") {
|
||||
return fmt.Errorf("jwt secret is required in release mode")
|
||||
}
|
||||
secret, err := generateSecret(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Generated JWT secret automatically")
|
||||
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
|
||||
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
|
||||
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
|
||||
}
|
||||
|
||||
// Generate admin password if not provided
|
||||
@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error {
|
||||
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||
}
|
||||
cfg.Admin.Password = password
|
||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
|
||||
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
// Test database connection
|
||||
|
||||
100
backend/internal/util/logredact/redact.go
Normal file
100
backend/internal/util/logredact/redact.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package logredact
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// maxRedactDepth 限制递归深度以防止栈溢出
|
||||
const maxRedactDepth = 32
|
||||
|
||||
var defaultSensitiveKeys = map[string]struct{}{
|
||||
"authorization_code": {},
|
||||
"code": {},
|
||||
"code_verifier": {},
|
||||
"access_token": {},
|
||||
"refresh_token": {},
|
||||
"id_token": {},
|
||||
"client_secret": {},
|
||||
"password": {},
|
||||
}
|
||||
|
||||
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
|
||||
if input == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
|
||||
if !ok {
|
||||
return map[string]any{}
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
func RedactJSON(raw []byte, extraKeys ...string) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return "<non-json payload redacted>"
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted := redactValueWithDepth(value, keys, 0)
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
func buildKeySet(extraKeys []string) map[string]struct{} {
|
||||
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
|
||||
for k := range defaultSensitiveKeys {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for _, key := range extraKeys {
|
||||
normalized := normalizeKey(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
keys[normalized] = struct{}{}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
|
||||
if depth > maxRedactDepth {
|
||||
return "<depth limit exceeded>"
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(v))
|
||||
for k, val := range v {
|
||||
if isSensitiveKey(k, keys) {
|
||||
out[k] = "***"
|
||||
continue
|
||||
}
|
||||
out[k] = redactValueWithDepth(val, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
out[i] = redactValueWithDepth(item, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveKey(key string, keys map[string]struct{}) bool {
|
||||
_, ok := keys[normalizeKey(key)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func normalizeKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
92
backend/internal/util/responseheaders/responseheaders.go
Normal file
92
backend/internal/util/responseheaders/responseheaders.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package responseheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// defaultAllowed 定义允许透传的响应头白名单
|
||||
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
|
||||
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
|
||||
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
|
||||
// - connection: 由 HTTP 库管理连接复用
|
||||
var defaultAllowed = map[string]struct{}{
|
||||
"content-type": {},
|
||||
"content-encoding": {},
|
||||
"content-language": {},
|
||||
"cache-control": {},
|
||||
"etag": {},
|
||||
"last-modified": {},
|
||||
"expires": {},
|
||||
"vary": {},
|
||||
"date": {},
|
||||
"x-request-id": {},
|
||||
"x-ratelimit-limit-requests": {},
|
||||
"x-ratelimit-limit-tokens": {},
|
||||
"x-ratelimit-remaining-requests": {},
|
||||
"x-ratelimit-remaining-tokens": {},
|
||||
"x-ratelimit-reset-requests": {},
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
"retry-after": {},
|
||||
"location": {},
|
||||
}
|
||||
|
||||
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"content-length": {},
|
||||
"transfer-encoding": {},
|
||||
"connection": {},
|
||||
}
|
||||
|
||||
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
|
||||
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
|
||||
for key := range defaultAllowed {
|
||||
allowed[key] = struct{}{}
|
||||
}
|
||||
for _, key := range cfg.AdditionalAllowed {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
allowed[normalized] = struct{}{}
|
||||
}
|
||||
|
||||
forceRemove := make(map[string]struct{}, len(cfg.ForceRemove))
|
||||
for _, key := range cfg.ForceRemove {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
forceRemove[normalized] = struct{}{}
|
||||
}
|
||||
|
||||
filtered := make(http.Header, len(src))
|
||||
for key, values := range src {
|
||||
lower := strings.ToLower(key)
|
||||
if _, blocked := forceRemove[lower]; blocked {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[lower]; !ok {
|
||||
continue
|
||||
}
|
||||
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
|
||||
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
filtered.Add(key, value)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
|
||||
filtered := FilterHeaders(src, cfg)
|
||||
for key, values := range filtered {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
121
backend/internal/util/urlvalidator/validator.go
Normal file
121
backend/internal/util/urlvalidator/validator.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package urlvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ValidationOptions struct {
|
||||
AllowedHosts []string
|
||||
RequireAllowlist bool
|
||||
AllowPrivate bool
|
||||
}
|
||||
|
||||
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("url is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid url: %s", trimmed)
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
if !opts.AllowPrivate && isBlockedHost(host) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
allowlist := normalizeAllowlist(opts.AllowedHosts)
|
||||
if opts.RequireAllowlist && len(allowlist) == 0 {
|
||||
return "", errors.New("allowlist is not configured")
|
||||
}
|
||||
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||
parsed.RawPath = ""
|
||||
return strings.TrimRight(parsed.String(), "/"), nil
|
||||
}
|
||||
|
||||
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
|
||||
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
|
||||
func ValidateResolvedIP(host string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns resolution failed: %w", err)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
|
||||
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAllowlist(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
entry := strings.ToLower(strings.TrimSpace(v))
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(entry); err == nil {
|
||||
entry = host
|
||||
}
|
||||
normalized = append(normalized, entry)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isAllowedHost(host string, allowlist []string) bool {
|
||||
for _, entry := range allowlist {
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(entry, "*.") {
|
||||
suffix := strings.TrimPrefix(entry, "*.")
|
||||
if host == suffix || strings.HasSuffix(host, "."+suffix) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if host == entry {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isBlockedHost(host string) bool {
|
||||
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
|
||||
return true
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -12,6 +12,8 @@ server:
|
||||
port: 8080
|
||||
# Mode: "debug" for development, "release" for production
|
||||
mode: "release"
|
||||
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
|
||||
trusted_proxies: []
|
||||
|
||||
# =============================================================================
|
||||
# Run Mode Configuration
|
||||
@@ -21,6 +23,48 @@ server:
|
||||
# - simple: Hides SaaS features and skips billing/balance checks
|
||||
run_mode: "standard"
|
||||
|
||||
# =============================================================================
|
||||
# CORS Configuration
|
||||
# =============================================================================
|
||||
cors:
|
||||
# Allowed origins list. Leave empty to disable cross-origin requests.
|
||||
allowed_origins: []
|
||||
# Allow credentials (cookies/authorization headers). Cannot be used with "*".
|
||||
allow_credentials: true
|
||||
|
||||
# =============================================================================
|
||||
# Security Configuration
|
||||
# =============================================================================
|
||||
security:
|
||||
url_allowlist:
|
||||
# Allowed upstream hosts for API proxying
|
||||
upstream_hosts:
|
||||
- "api.openai.com"
|
||||
- "api.anthropic.com"
|
||||
- "generativelanguage.googleapis.com"
|
||||
- "cloudcode-pa.googleapis.com"
|
||||
- "*.openai.azure.com"
|
||||
# Allowed hosts for pricing data download
|
||||
pricing_hosts:
|
||||
- "raw.githubusercontent.com"
|
||||
# Allowed hosts for CRS sync (required when using CRS sync)
|
||||
crs_hosts: []
|
||||
# Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks)
|
||||
allow_private_hosts: false
|
||||
response_headers:
|
||||
# Extra allowed response headers from upstream
|
||||
additional_allowed: []
|
||||
# Force-remove response headers from upstream
|
||||
force_remove: []
|
||||
csp:
|
||||
# Enable Content-Security-Policy header
|
||||
enabled: true
|
||||
# Default CSP policy (override if you host assets on other domains)
|
||||
policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
proxy_probe:
|
||||
# Allow skipping TLS verification for proxy probe (debug only)
|
||||
insecure_skip_verify: false
|
||||
|
||||
# =============================================================================
|
||||
# 网关配置
|
||||
# =============================================================================
|
||||
@@ -77,7 +121,7 @@ jwt:
|
||||
# IMPORTANT: Change this to a random string in production!
|
||||
# Generate with: openssl rand -hex 32
|
||||
secret: "change-this-to-a-secure-random-string"
|
||||
# Token expiration time in hours
|
||||
# Token expiration time in hours (max 24)
|
||||
expire_hour: 24
|
||||
|
||||
# =============================================================================
|
||||
@@ -122,6 +166,23 @@ pricing:
|
||||
# Hash check interval in minutes
|
||||
hash_check_interval_minutes: 10
|
||||
|
||||
# =============================================================================
|
||||
# Billing Configuration
|
||||
# =============================================================================
|
||||
billing:
|
||||
circuit_breaker:
|
||||
enabled: true
|
||||
failure_threshold: 5
|
||||
reset_timeout_seconds: 30
|
||||
half_open_requests: 3
|
||||
|
||||
# =============================================================================
|
||||
# Turnstile Configuration
|
||||
# =============================================================================
|
||||
turnstile:
|
||||
# Require Turnstile in release mode (when enabled, login/register will fail if not configured)
|
||||
required: false
|
||||
|
||||
# =============================================================================
|
||||
# Gemini OAuth (Required for Gemini accounts)
|
||||
# =============================================================================
|
||||
|
||||
@@ -26,14 +26,37 @@ export interface SystemSettings {
|
||||
smtp_host: string
|
||||
smtp_port: number
|
||||
smtp_username: string
|
||||
smtp_password: string
|
||||
smtp_password_configured: boolean
|
||||
smtp_from_email: string
|
||||
smtp_from_name: string
|
||||
smtp_use_tls: boolean
|
||||
// Cloudflare Turnstile settings
|
||||
turnstile_enabled: boolean
|
||||
turnstile_site_key: string
|
||||
turnstile_secret_key: string
|
||||
turnstile_secret_key_configured: boolean
|
||||
}
|
||||
|
||||
export interface UpdateSettingsRequest {
|
||||
registration_enabled?: boolean
|
||||
email_verify_enabled?: boolean
|
||||
default_balance?: number
|
||||
default_concurrency?: number
|
||||
site_name?: string
|
||||
site_logo?: string
|
||||
site_subtitle?: string
|
||||
api_base_url?: string
|
||||
contact_info?: string
|
||||
doc_url?: string
|
||||
smtp_host?: string
|
||||
smtp_port?: number
|
||||
smtp_username?: string
|
||||
smtp_password?: string
|
||||
smtp_from_email?: string
|
||||
smtp_from_name?: string
|
||||
smtp_use_tls?: boolean
|
||||
turnstile_enabled?: boolean
|
||||
turnstile_site_key?: string
|
||||
turnstile_secret_key?: string
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -50,7 +73,7 @@ export async function getSettings(): Promise<SystemSettings> {
|
||||
* @param settings - Partial settings to update
|
||||
* @returns Updated settings
|
||||
*/
|
||||
export async function updateSettings(settings: Partial<SystemSettings>): Promise<SystemSettings> {
|
||||
export async function updateSettings(settings: UpdateSettingsRequest): Promise<SystemSettings> {
|
||||
const { data } = await apiClient.put<SystemSettings>('/admin/settings', settings)
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -62,8 +62,24 @@ apiClient.interceptors.response.use(
|
||||
|
||||
// 401: Unauthorized - clear token and redirect to login
|
||||
if (status === 401) {
|
||||
const hasToken = !!localStorage.getItem('auth_token')
|
||||
const url = error.config?.url || ''
|
||||
const isAuthEndpoint =
|
||||
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
|
||||
const headers = error.config?.headers as Record<string, unknown> | undefined
|
||||
const authHeader = headers?.Authorization ?? headers?.authorization
|
||||
const sentAuth =
|
||||
typeof authHeader === 'string'
|
||||
? authHeader.trim() !== ''
|
||||
: Array.isArray(authHeader)
|
||||
? authHeader.length > 0
|
||||
: !!authHeader
|
||||
|
||||
localStorage.removeItem('auth_token')
|
||||
localStorage.removeItem('auth_user')
|
||||
if ((hasToken || sentAuth) && !isAuthEndpoint) {
|
||||
sessionStorage.setItem('auth_expired', '1')
|
||||
}
|
||||
// Only redirect if not already on login page
|
||||
if (!window.location.pathname.includes('/login')) {
|
||||
window.location.href = '/login'
|
||||
|
||||
@@ -136,16 +136,16 @@
|
||||
<ol
|
||||
class="list-inside list-decimal space-y-1 text-xs text-amber-700 dark:text-amber-300"
|
||||
>
|
||||
<li v-html="t('admin.accounts.oauth.step1')"></li>
|
||||
<li v-html="t('admin.accounts.oauth.step2')"></li>
|
||||
<li v-html="t('admin.accounts.oauth.step3')"></li>
|
||||
<li v-html="t('admin.accounts.oauth.step4')"></li>
|
||||
<li v-html="t('admin.accounts.oauth.step5')"></li>
|
||||
<li v-html="t('admin.accounts.oauth.step6')"></li>
|
||||
<li>{{ t('admin.accounts.oauth.step1') }}</li>
|
||||
<li>{{ t('admin.accounts.oauth.step2') }}</li>
|
||||
<li>{{ t('admin.accounts.oauth.step3') }}</li>
|
||||
<li>{{ t('admin.accounts.oauth.step4') }}</li>
|
||||
<li>{{ t('admin.accounts.oauth.step5') }}</li>
|
||||
<li>{{ t('admin.accounts.oauth.step6') }}</li>
|
||||
</ol>
|
||||
<p
|
||||
class="mt-2 text-xs text-amber-600 dark:text-amber-400"
|
||||
v-html="t('admin.accounts.oauth.sessionKeyFormat')"
|
||||
v-text="t('admin.accounts.oauth.sessionKeyFormat')"
|
||||
></p>
|
||||
</div>
|
||||
|
||||
@@ -390,7 +390,7 @@
|
||||
>
|
||||
<p
|
||||
class="text-xs text-amber-800 dark:text-amber-300"
|
||||
v-html="oauthImportantNotice"
|
||||
v-text="oauthImportantNotice"
|
||||
></p>
|
||||
</div>
|
||||
<!-- Proxy Warning (for non-OpenAI) -->
|
||||
@@ -400,7 +400,7 @@
|
||||
>
|
||||
<p
|
||||
class="text-xs text-yellow-800 dark:text-yellow-300"
|
||||
v-html="t('admin.accounts.oauth.proxyWarning')"
|
||||
v-text="t('admin.accounts.oauth.proxyWarning')"
|
||||
></p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -423,7 +423,7 @@
|
||||
</p>
|
||||
<p
|
||||
class="mb-3 text-sm text-blue-700 dark:text-blue-300"
|
||||
v-html="oauthAuthCodeDesc"
|
||||
v-text="oauthAuthCodeDesc"
|
||||
></p>
|
||||
<div>
|
||||
<label class="input-label">
|
||||
|
||||
@@ -85,7 +85,7 @@
|
||||
</button>
|
||||
</div>
|
||||
<!-- Code Content -->
|
||||
<pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto"><code v-html="file.highlighted"></code></pre>
|
||||
<pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto"><code v-text="file.content"></code></pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -142,7 +142,6 @@ interface TabConfig {
|
||||
interface FileConfig {
|
||||
path: string
|
||||
content: string
|
||||
highlighted: string
|
||||
hint?: string // Optional hint message for this file
|
||||
}
|
||||
|
||||
@@ -227,13 +226,6 @@ const platformNote = computed(() => {
|
||||
})
|
||||
|
||||
// Syntax highlighting helpers
|
||||
const keyword = (text: string) => `<span class="text-purple-400">${text}</span>`
|
||||
const variable = (text: string) => `<span class="text-cyan-400">${text}</span>`
|
||||
const string = (text: string) => `<span class="text-green-400">${text}</span>`
|
||||
const operator = (text: string) => `<span class="text-yellow-400">${text}</span>`
|
||||
const comment = (text: string) => `<span class="text-gray-500">${text}</span>`
|
||||
const key = (text: string) => `<span class="text-blue-400">${text}</span>`
|
||||
|
||||
// Generate file configs based on platform and active tab
|
||||
const currentFiles = computed((): FileConfig[] => {
|
||||
const baseUrl = props.baseUrl || window.location.origin
|
||||
@@ -249,37 +241,29 @@ const currentFiles = computed((): FileConfig[] => {
|
||||
function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] {
|
||||
let path: string
|
||||
let content: string
|
||||
let highlighted: string
|
||||
|
||||
switch (activeTab.value) {
|
||||
case 'unix':
|
||||
path = 'Terminal'
|
||||
content = `export ANTHROPIC_BASE_URL="${baseUrl}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${apiKey}"`
|
||||
highlighted = `${keyword('export')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
|
||||
${keyword('export')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
|
||||
break
|
||||
case 'cmd':
|
||||
path = 'Command Prompt'
|
||||
content = `set ANTHROPIC_BASE_URL=${baseUrl}
|
||||
set ANTHROPIC_AUTH_TOKEN=${apiKey}`
|
||||
highlighted = `${keyword('set')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${baseUrl}
|
||||
${keyword('set')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${apiKey}`
|
||||
break
|
||||
case 'powershell':
|
||||
path = 'PowerShell'
|
||||
content = `$env:ANTHROPIC_BASE_URL="${baseUrl}"
|
||||
$env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
|
||||
highlighted = `${keyword('$env:')}${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
|
||||
${keyword('$env:')}${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
|
||||
break
|
||||
default:
|
||||
path = 'Terminal'
|
||||
content = ''
|
||||
highlighted = ''
|
||||
}
|
||||
|
||||
return [{ path, content, highlighted }]
|
||||
return [{ path, content }]
|
||||
}
|
||||
|
||||
function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] {
|
||||
@@ -301,40 +285,20 @@ base_url = "${baseUrl}"
|
||||
wire_api = "responses"
|
||||
requires_openai_auth = true`
|
||||
|
||||
const configHighlighted = `${key('model_provider')} ${operator('=')} ${string('"sub2api"')}
|
||||
${key('model')} ${operator('=')} ${string('"gpt-5.2-codex"')}
|
||||
${key('model_reasoning_effort')} ${operator('=')} ${string('"high"')}
|
||||
${key('network_access')} ${operator('=')} ${string('"enabled"')}
|
||||
${key('disable_response_storage')} ${operator('=')} ${keyword('true')}
|
||||
${key('windows_wsl_setup_acknowledged')} ${operator('=')} ${keyword('true')}
|
||||
${key('model_verbosity')} ${operator('=')} ${string('"high"')}
|
||||
|
||||
${comment('[model_providers.sub2api]')}
|
||||
${key('name')} ${operator('=')} ${string('"sub2api"')}
|
||||
${key('base_url')} ${operator('=')} ${string(`"${baseUrl}"`)}
|
||||
${key('wire_api')} ${operator('=')} ${string('"responses"')}
|
||||
${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
|
||||
|
||||
// auth.json content
|
||||
const authContent = `{
|
||||
"OPENAI_API_KEY": "${apiKey}"
|
||||
}`
|
||||
|
||||
const authHighlighted = `{
|
||||
${key('"OPENAI_API_KEY"')}: ${string(`"${apiKey}"`)}
|
||||
}`
|
||||
|
||||
return [
|
||||
{
|
||||
path: `${configDir}/config.toml`,
|
||||
content: configContent,
|
||||
highlighted: configHighlighted,
|
||||
hint: t('keys.useKeyModal.openai.configTomlHint')
|
||||
},
|
||||
{
|
||||
path: `${configDir}/auth.json`,
|
||||
content: authContent,
|
||||
highlighted: authHighlighted
|
||||
content: authContent
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { getPublicSettings } from '@/api/auth'
|
||||
import { sanitizeUrl } from '@/utils/url'
|
||||
|
||||
const siteName = ref('Sub2API')
|
||||
const siteLogo = ref('')
|
||||
@@ -74,7 +75,7 @@ onMounted(async () => {
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
siteName.value = settings.site_name || 'Sub2API'
|
||||
siteLogo.value = settings.site_logo || ''
|
||||
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
|
||||
siteSubtitle.value = settings.site_subtitle || 'Subscription to API Conversion Platform'
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings:', error)
|
||||
|
||||
@@ -197,6 +197,7 @@ export default {
|
||||
registrationFailed: 'Registration failed. Please try again.',
|
||||
loginSuccess: 'Login successful! Welcome back.',
|
||||
accountCreatedSuccess: 'Account created successfully! Welcome to {siteName}.',
|
||||
reloginRequired: 'Session expired. Please log in again.',
|
||||
turnstileExpired: 'Verification expired, please try again',
|
||||
turnstileFailed: 'Verification failed, please try again',
|
||||
completeVerification: 'Please complete the verification',
|
||||
@@ -991,13 +992,13 @@ export default {
|
||||
'One sessionKey per line, e.g.:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
|
||||
howToGetSessionKey: 'How to get sessionKey',
|
||||
step1: 'Login to <strong>claude.ai</strong> in your browser',
|
||||
step2: 'Press <kbd>F12</kbd> to open Developer Tools',
|
||||
step3: 'Go to <strong>Application</strong> tab',
|
||||
step4: 'Find <strong>Cookies</strong> → <strong>https://claude.ai</strong>',
|
||||
step5: 'Find the row with key <strong>sessionKey</strong>',
|
||||
step6: 'Copy the <strong>Value</strong>',
|
||||
sessionKeyFormat: 'sessionKey usually starts with <code>sk-ant-sid01-</code>',
|
||||
step1: 'Login to claude.ai in your browser',
|
||||
step2: 'Press F12 to open Developer Tools',
|
||||
step3: 'Go to Application tab',
|
||||
step4: 'Find Cookies → https://claude.ai',
|
||||
step5: 'Find the row with key sessionKey',
|
||||
step6: 'Copy the Value',
|
||||
sessionKeyFormat: 'sessionKey usually starts with sk-ant-sid01-',
|
||||
startAutoAuth: 'Start Auto-Auth',
|
||||
authorizing: 'Authorizing...',
|
||||
followSteps: 'Follow these steps to authorize your Claude account:',
|
||||
@@ -1009,10 +1010,10 @@ export default {
|
||||
openUrlDesc:
|
||||
'Open the authorization URL in a new tab, log in to your Claude account and authorize.',
|
||||
proxyWarning:
|
||||
'<strong>Note:</strong> If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
|
||||
'Note: If you configured a proxy, make sure your browser uses the same proxy to access the authorization page.',
|
||||
step3EnterCode: 'Enter the Authorization Code',
|
||||
authCodeDesc:
|
||||
'After authorization is complete, the page will display an <strong>Authorization Code</strong>. Copy and paste it below:',
|
||||
'After authorization is complete, the page will display an Authorization Code. Copy and paste it below:',
|
||||
authCode: 'Authorization Code',
|
||||
authCodePlaceholder: 'Paste the Authorization Code from Claude page...',
|
||||
authCodeHint: 'Paste the Authorization Code copied from the Claude page',
|
||||
@@ -1033,10 +1034,10 @@ export default {
|
||||
openUrlDesc:
|
||||
'Open the authorization URL in a new tab, log in to your OpenAI account and authorize.',
|
||||
importantNotice:
|
||||
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to <code>http://localhost...</code>, the authorization is complete.',
|
||||
'Important: The page may take a while to load after authorization. Please wait patiently. When the browser address bar changes to http://localhost..., the authorization is complete.',
|
||||
step3EnterCode: 'Enter Authorization URL or Code',
|
||||
authCodeDesc:
|
||||
'After authorization is complete, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
|
||||
'After authorization is complete, when the page URL becomes http://localhost:xxx/auth/callback?code=...:',
|
||||
authCode: 'Authorization URL or Code',
|
||||
authCodePlaceholder:
|
||||
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
||||
@@ -1059,7 +1060,7 @@ export default {
|
||||
'Open the authorization URL in a new tab, log in to your Google account and authorize.',
|
||||
step3EnterCode: 'Enter Authorization URL or Code',
|
||||
authCodeDesc:
|
||||
'After authorization, copy the callback URL (recommended) or just the <code>code</code> and paste it below.',
|
||||
'After authorization, copy the callback URL (recommended) or just the code and paste it below.',
|
||||
authCode: 'Callback URL or Code',
|
||||
authCodePlaceholder:
|
||||
'Option 1 (recommended): Paste the callback URL\nOption 2: Paste only the code value',
|
||||
@@ -1100,10 +1101,10 @@ export default {
|
||||
step2OpenUrl: 'Open the URL in your browser and complete authorization',
|
||||
openUrlDesc: 'Open the authorization URL in a new tab, log in to your Google account and authorize.',
|
||||
importantNotice:
|
||||
'<strong>Important:</strong> The page may take a while to load after authorization. Please wait patiently. When the browser address bar shows <code>http://localhost...</code>, authorization is complete.',
|
||||
'Important: The page may take a while to load after authorization. Please wait patiently. When the browser address bar shows http://localhost..., authorization is complete.',
|
||||
step3EnterCode: 'Enter Authorization URL or Code',
|
||||
authCodeDesc:
|
||||
'After authorization, when the page URL becomes <code>http://localhost:xxx/auth/callback?code=...</code>:',
|
||||
'After authorization, when the page URL becomes http://localhost:xxx/auth/callback?code=...:',
|
||||
authCode: 'Authorization URL or Code',
|
||||
authCodePlaceholder:
|
||||
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
||||
@@ -1377,7 +1378,8 @@ export default {
|
||||
siteKey: 'Site Key',
|
||||
secretKey: 'Secret Key',
|
||||
siteKeyHint: 'Get this from your Cloudflare Dashboard',
|
||||
secretKeyHint: 'Server-side verification key (keep this secret)'
|
||||
secretKeyHint: 'Server-side verification key (keep this secret)',
|
||||
secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.'
|
||||
},
|
||||
defaults: {
|
||||
title: 'Default User Settings',
|
||||
@@ -1428,6 +1430,8 @@ export default {
|
||||
password: 'SMTP Password',
|
||||
passwordPlaceholder: '********',
|
||||
passwordHint: 'Leave empty to keep existing password',
|
||||
passwordConfiguredPlaceholder: '********',
|
||||
passwordConfiguredHint: 'Password configured. Leave empty to keep the current value.',
|
||||
fromEmail: 'From Email',
|
||||
fromEmailPlaceholder: "noreply{'@'}example.com",
|
||||
fromName: 'From Name',
|
||||
|
||||
@@ -194,6 +194,7 @@ export default {
|
||||
registrationFailed: '注册失败,请重试。',
|
||||
loginSuccess: '登录成功!欢迎回来。',
|
||||
accountCreatedSuccess: '账户创建成功!欢迎使用 {siteName}。',
|
||||
reloginRequired: '会话已过期,请重新登录。',
|
||||
turnstileExpired: '验证已过期,请重试',
|
||||
turnstileFailed: '验证失败,请重试',
|
||||
completeVerification: '请完成验证',
|
||||
@@ -1137,13 +1138,13 @@ export default {
|
||||
'每行一个 sessionKey,例如:\nsk-ant-sid01-xxxxx...\nsk-ant-sid01-yyyyy...',
|
||||
sessionKeyPlaceholderSingle: 'sk-ant-sid01-xxxxx...',
|
||||
howToGetSessionKey: '如何获取 sessionKey',
|
||||
step1: '在浏览器中登录 <strong>claude.ai</strong>',
|
||||
step2: '按 <kbd>F12</kbd> 打开开发者工具',
|
||||
step3: '切换到 <strong>Application</strong> 标签',
|
||||
step4: '找到 <strong>Cookies</strong> → <strong>https://claude.ai</strong>',
|
||||
step5: '找到 <strong>sessionKey</strong> 所在行',
|
||||
step6: '复制 <strong>Value</strong> 列的值',
|
||||
sessionKeyFormat: 'sessionKey 通常以 <code>sk-ant-sid01-</code> 开头',
|
||||
step1: '在浏览器中登录 claude.ai',
|
||||
step2: '按 F12 打开开发者工具',
|
||||
step3: '切换到 Application 标签',
|
||||
step4: '找到 Cookies → https://claude.ai',
|
||||
step5: '找到 sessionKey 所在行',
|
||||
step6: '复制 Value 列的值',
|
||||
sessionKeyFormat: 'sessionKey 通常以 sk-ant-sid01- 开头',
|
||||
startAutoAuth: '开始自动授权',
|
||||
authorizing: '授权中...',
|
||||
followSteps: '按照以下步骤授权您的 Claude 账号:',
|
||||
@@ -1154,9 +1155,9 @@ export default {
|
||||
step2OpenUrl: '在浏览器中打开 URL 并完成授权',
|
||||
openUrlDesc: '在新标签页中打开授权 URL,登录您的 Claude 账号并授权。',
|
||||
proxyWarning:
|
||||
'<strong>注意:</strong>如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
|
||||
'注意:如果您配置了代理,请确保浏览器使用相同的代理访问授权页面。',
|
||||
step3EnterCode: '输入授权码',
|
||||
authCodeDesc: '授权完成后,页面会显示一个 <strong>授权码</strong>。复制并粘贴到下方:',
|
||||
authCodeDesc: '授权完成后,页面会显示一个授权码。复制并粘贴到下方:',
|
||||
authCode: '授权码',
|
||||
authCodePlaceholder: '粘贴 Claude 页面的授权码...',
|
||||
authCodeHint: '粘贴从 Claude 页面复制的授权码',
|
||||
@@ -1176,10 +1177,10 @@ export default {
|
||||
step2OpenUrl: '在浏览器中打开链接并完成授权',
|
||||
openUrlDesc: '请在新标签页中打开授权链接,登录您的 OpenAI 账户并授权。',
|
||||
importantNotice:
|
||||
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
|
||||
'重要提示:授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 http://localhost... 开头时,表示授权已完成。',
|
||||
step3EnterCode: '输入授权链接或 Code',
|
||||
authCodeDesc:
|
||||
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
|
||||
'授权完成后,当页面地址变为 http://localhost:xxx/auth/callback?code=... 时:',
|
||||
authCode: '授权链接或 Code',
|
||||
authCodePlaceholder:
|
||||
'方式1:复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2:仅复制 code 参数的值',
|
||||
@@ -1198,7 +1199,7 @@ export default {
|
||||
step2OpenUrl: '在浏览器中打开链接并完成授权',
|
||||
openUrlDesc: '请在新标签页中打开授权链接,登录您的 Google 账户并授权。',
|
||||
step3EnterCode: '输入回调链接或 Code',
|
||||
authCodeDesc: '授权完成后,复制浏览器跳转后的回调链接(推荐)或仅复制 <code>code</code>,粘贴到下方即可。',
|
||||
authCodeDesc: '授权完成后,复制浏览器跳转后的回调链接(推荐)或仅复制 code,粘贴到下方即可。',
|
||||
authCode: '回调链接或 Code',
|
||||
authCodePlaceholder: '方式1(推荐):粘贴回调链接\n方式2:仅粘贴 code 参数的值',
|
||||
authCodeHint: '系统会自动从链接中解析 code/state。',
|
||||
@@ -1233,10 +1234,10 @@ export default {
|
||||
step2OpenUrl: '在浏览器中打开链接并完成授权',
|
||||
openUrlDesc: '请在新标签页中打开授权链接,登录您的 Google 账户并授权。',
|
||||
importantNotice:
|
||||
'<strong>重要提示:</strong>授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 <code>http://localhost...</code> 开头时,表示授权已完成。',
|
||||
'重要提示:授权后页面可能会加载较长时间,请耐心等待。当浏览器地址栏变为 http://localhost... 开头时,表示授权已完成。',
|
||||
step3EnterCode: '输入授权链接或 Code',
|
||||
authCodeDesc:
|
||||
'授权完成后,当页面地址变为 <code>http://localhost:xxx/auth/callback?code=...</code> 时:',
|
||||
'授权完成后,当页面地址变为 http://localhost:xxx/auth/callback?code=... 时:',
|
||||
authCode: '授权链接或 Code',
|
||||
authCodePlaceholder:
|
||||
'方式1:复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2:仅复制 code 参数的值',
|
||||
@@ -1576,7 +1577,8 @@ export default {
|
||||
siteKey: '站点密钥',
|
||||
secretKey: '私密密钥',
|
||||
siteKeyHint: '从 Cloudflare Dashboard 获取',
|
||||
secretKeyHint: '服务端验证密钥(请保密)'
|
||||
secretKeyHint: '服务端验证密钥(请保密)',
|
||||
secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。'
|
||||
},
|
||||
defaults: {
|
||||
title: '用户默认设置',
|
||||
@@ -1626,6 +1628,8 @@ export default {
|
||||
password: 'SMTP 密码',
|
||||
passwordPlaceholder: '********',
|
||||
passwordHint: '留空以保留现有密码',
|
||||
passwordConfiguredPlaceholder: '********',
|
||||
passwordConfiguredHint: '密码已配置,留空以保留当前值。',
|
||||
fromEmail: '发件人邮箱',
|
||||
fromEmailPlaceholder: "noreply{'@'}example.com",
|
||||
fromName: '发件人名称',
|
||||
|
||||
37
frontend/src/utils/url.ts
Normal file
37
frontend/src/utils/url.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
/**
|
||||
* 验证并规范化 URL
|
||||
* 默认只接受绝对 URL(以 http:// 或 https:// 开头),可按需允许相对路径
|
||||
* @param value 用户输入的 URL
|
||||
* @returns 规范化后的 URL,如果无效则返回空字符串
|
||||
*/
|
||||
type SanitizeOptions = {
|
||||
allowRelative?: boolean
|
||||
}
|
||||
|
||||
export function sanitizeUrl(value: string, options: SanitizeOptions = {}): string {
|
||||
const trimmed = value.trim()
|
||||
if (!trimmed) {
|
||||
return ''
|
||||
}
|
||||
|
||||
if (options.allowRelative && trimmed.startsWith('/')) {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// 只接受绝对 URL,不使用 base URL 来避免相对路径被解析为当前域名
|
||||
// 检查是否以 http:// 或 https:// 开头
|
||||
if (!trimmed.match(/^https?:\/\//i)) {
|
||||
return ''
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = new URL(trimmed)
|
||||
const protocol = parsed.protocol.toLowerCase()
|
||||
if (protocol !== 'http:' && protocol !== 'https:') {
|
||||
return ''
|
||||
}
|
||||
return parsed.toString()
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
@@ -493,6 +493,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import { getPublicSettings } from '@/api/auth'
|
||||
import { useAuthStore } from '@/stores'
|
||||
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
|
||||
import { sanitizeUrl } from '@/utils/url'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -549,9 +550,9 @@ onMounted(async () => {
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
siteName.value = settings.site_name || 'Sub2API'
|
||||
siteLogo.value = settings.site_logo || ''
|
||||
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
|
||||
siteSubtitle.value = settings.site_subtitle || 'AI API Gateway Platform'
|
||||
docUrl.value = settings.doc_url || ''
|
||||
docUrl.value = sanitizeUrl(settings.doc_url || '', { allowRelative: true })
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings:', error)
|
||||
}
|
||||
|
||||
@@ -255,7 +255,11 @@
|
||||
placeholder="0x4AAAAAAA..."
|
||||
/>
|
||||
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.settings.turnstile.secretKeyHint') }}
|
||||
{{
|
||||
form.turnstile_secret_key_configured
|
||||
? t('admin.settings.turnstile.secretKeyConfiguredHint')
|
||||
: t('admin.settings.turnstile.secretKeyHint')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -577,10 +581,18 @@
|
||||
v-model="form.smtp_password"
|
||||
type="password"
|
||||
class="input"
|
||||
:placeholder="t('admin.settings.smtp.passwordPlaceholder')"
|
||||
:placeholder="
|
||||
form.smtp_password_configured
|
||||
? t('admin.settings.smtp.passwordConfiguredPlaceholder')
|
||||
: t('admin.settings.smtp.passwordPlaceholder')
|
||||
"
|
||||
/>
|
||||
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.settings.smtp.passwordHint') }}
|
||||
{{
|
||||
form.smtp_password_configured
|
||||
? t('admin.settings.smtp.passwordConfiguredHint')
|
||||
: t('admin.settings.smtp.passwordHint')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
@@ -713,7 +725,7 @@
|
||||
import { ref, reactive, onMounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { adminAPI } from '@/api'
|
||||
import type { SystemSettings } from '@/api/admin/settings'
|
||||
import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import Toggle from '@/components/common/Toggle.vue'
|
||||
import { useAppStore } from '@/stores'
|
||||
@@ -735,7 +747,12 @@ const adminApiKeyMasked = ref('')
|
||||
const adminApiKeyOperating = ref(false)
|
||||
const newAdminApiKey = ref('')
|
||||
|
||||
const form = reactive<SystemSettings>({
|
||||
type SettingsForm = SystemSettings & {
|
||||
smtp_password: string
|
||||
turnstile_secret_key: string
|
||||
}
|
||||
|
||||
const form = reactive<SettingsForm>({
|
||||
registration_enabled: true,
|
||||
email_verify_enabled: false,
|
||||
default_balance: 0,
|
||||
@@ -750,13 +767,15 @@ const form = reactive<SystemSettings>({
|
||||
smtp_port: 587,
|
||||
smtp_username: '',
|
||||
smtp_password: '',
|
||||
smtp_password_configured: false,
|
||||
smtp_from_email: '',
|
||||
smtp_from_name: '',
|
||||
smtp_use_tls: true,
|
||||
// Cloudflare Turnstile
|
||||
turnstile_enabled: false,
|
||||
turnstile_site_key: '',
|
||||
turnstile_secret_key: ''
|
||||
turnstile_secret_key: '',
|
||||
turnstile_secret_key_configured: false
|
||||
})
|
||||
|
||||
function handleLogoUpload(event: Event) {
|
||||
@@ -802,6 +821,8 @@ async function loadSettings() {
|
||||
try {
|
||||
const settings = await adminAPI.settings.getSettings()
|
||||
Object.assign(form, settings)
|
||||
form.smtp_password = ''
|
||||
form.turnstile_secret_key = ''
|
||||
} catch (error: any) {
|
||||
appStore.showError(
|
||||
t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError'))
|
||||
@@ -814,7 +835,32 @@ async function loadSettings() {
|
||||
async function saveSettings() {
|
||||
saving.value = true
|
||||
try {
|
||||
await adminAPI.settings.updateSettings(form)
|
||||
const payload: UpdateSettingsRequest = {
|
||||
registration_enabled: form.registration_enabled,
|
||||
email_verify_enabled: form.email_verify_enabled,
|
||||
default_balance: form.default_balance,
|
||||
default_concurrency: form.default_concurrency,
|
||||
site_name: form.site_name,
|
||||
site_logo: form.site_logo,
|
||||
site_subtitle: form.site_subtitle,
|
||||
api_base_url: form.api_base_url,
|
||||
contact_info: form.contact_info,
|
||||
doc_url: form.doc_url,
|
||||
smtp_host: form.smtp_host,
|
||||
smtp_port: form.smtp_port,
|
||||
smtp_username: form.smtp_username,
|
||||
smtp_password: form.smtp_password || undefined,
|
||||
smtp_from_email: form.smtp_from_email,
|
||||
smtp_from_name: form.smtp_from_name,
|
||||
smtp_use_tls: form.smtp_use_tls,
|
||||
turnstile_enabled: form.turnstile_enabled,
|
||||
turnstile_site_key: form.turnstile_site_key,
|
||||
turnstile_secret_key: form.turnstile_secret_key || undefined
|
||||
}
|
||||
const updated = await adminAPI.settings.updateSettings(payload)
|
||||
Object.assign(form, updated)
|
||||
form.smtp_password = ''
|
||||
form.turnstile_secret_key = ''
|
||||
// Refresh cached public settings so sidebar/header update immediately
|
||||
await appStore.fetchPublicSettings(true)
|
||||
appStore.showSuccess(t('admin.settings.settingsSaved'))
|
||||
|
||||
@@ -277,6 +277,14 @@ const errors = reactive({
|
||||
// ==================== Lifecycle ====================
|
||||
|
||||
onMounted(async () => {
|
||||
const expiredFlag = sessionStorage.getItem('auth_expired')
|
||||
if (expiredFlag) {
|
||||
sessionStorage.removeItem('auth_expired')
|
||||
const message = t('auth.reloginRequired')
|
||||
errorMessage.value = message
|
||||
appStore.showWarning(message)
|
||||
}
|
||||
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
turnstileEnabled.value = settings.turnstile_enabled
|
||||
|
||||
Reference in New Issue
Block a user