// Package config loads the application configuration from a YAML file,
// environment variables and built-in defaults, and validates the result.
package config

import (
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"

	"github.com/spf13/viper"
)

// placeholderJWTSecret is the built-in default for jwt.secret. Validate
// rejects it when the server runs in release mode.
const placeholderJWTSecret = "change-me-in-production"

// dsnValueEscaper escapes the two characters that are special inside a
// single-quoted value of a keyword/value connection string.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// Config is the root of the application configuration.
type Config struct {
	Server    ServerConfig    `mapstructure:"server"`
	Database  DatabaseConfig  `mapstructure:"database"`
	Redis     RedisConfig     `mapstructure:"redis"`
	JWT       JWTConfig       `mapstructure:"jwt"`
	Default   DefaultConfig   `mapstructure:"default"`
	RateLimit RateLimitConfig `mapstructure:"rate_limit"`
	Pricing   PricingConfig   `mapstructure:"pricing"`
	Gateway   GatewayConfig   `mapstructure:"gateway"`
	// Timezone name, e.g. "Asia/Shanghai", "UTC".
	Timezone string `mapstructure:"timezone"`
}

// PricingConfig describes where model pricing data comes from and how often
// it is refreshed.
type PricingConfig struct {
	// RemoteURL is the remote URL of the pricing data (a LiteLLM mirror by default).
	RemoteURL string `mapstructure:"remote_url"`
	// HashURL is the URL of the hash file used for verification.
	HashURL string `mapstructure:"hash_url"`
	// DataDir is the local data directory.
	DataDir string `mapstructure:"data_dir"`
	// FallbackFile is the path of the fallback file.
	FallbackFile string `mapstructure:"fallback_file"`
	// UpdateIntervalHours is the update interval, in hours.
	UpdateIntervalHours int `mapstructure:"update_interval_hours"`
	// HashCheckIntervalMinutes is the hash check interval, in minutes.
	HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"`
}

// ServerConfig holds the HTTP server settings.
type ServerConfig struct {
	Host string `mapstructure:"host"`
	Port int    `mapstructure:"port"`
	// Mode is "debug" or "release".
	Mode string `mapstructure:"mode"`
	// ReadHeaderTimeout is the timeout for reading request headers, in seconds.
	ReadHeaderTimeout int `mapstructure:"read_header_timeout"`
	// IdleTimeout is the idle connection timeout, in seconds.
	IdleTimeout int `mapstructure:"idle_timeout"`
}

// GatewayConfig holds the API gateway settings.
type GatewayConfig struct {
	// ResponseHeaderTimeout is how long to wait for the upstream response
	// headers, in seconds; 0 means no timeout.
	// Note: this does not affect streaming data transfer, it only bounds the
	// wait for the response headers.
	ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
}

// Address returns the server listen address in "host:port" form. IPv6
// literal hosts are wrapped in brackets so the result stays parseable.
func (s *ServerConfig) Address() string {
	return joinHostPort(s.Host, s.Port)
}

// DatabaseConfig holds the PostgreSQL connection settings.
type DatabaseConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	User     string `mapstructure:"user"`
	Password string `mapstructure:"password"`
	DBName   string `mapstructure:"dbname"`
	SSLMode  string `mapstructure:"sslmode"`
}

// DSN returns the keyword/value connection string for the database.
//
// Values that are empty or contain whitespace, a single quote or a backslash
// are single-quoted and escaped; without that, the parser would read the
// following "key=value" token as the value. All other values are emitted
// verbatim.
func (d *DatabaseConfig) DSN() string {
	return fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(d.Host),
		d.Port,
		quoteDSNValue(d.User),
		quoteDSNValue(d.Password),
		quoteDSNValue(d.DBName),
		quoteDSNValue(d.SSLMode),
	)
}

// DSNWithTimezone returns the DSN with a TimeZone setting appended. An empty
// tz falls back to "Asia/Shanghai".
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
	if tz == "" {
		tz = "Asia/Shanghai"
	}
	return d.DSN() + " TimeZone=" + quoteDSNValue(tz)
}

// RedisConfig holds the Redis connection settings.
type RedisConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	Password string `mapstructure:"password"`
	DB       int    `mapstructure:"db"`
}

// Address returns the Redis address in "host:port" form. IPv6 literal hosts
// are wrapped in brackets so the result stays parseable.
func (r *RedisConfig) Address() string {
	return joinHostPort(r.Host, r.Port)
}

// JWTConfig holds the JWT signing secret and the token lifetime in hours.
type JWTConfig struct {
	Secret     string `mapstructure:"secret"`
	ExpireHour int    `mapstructure:"expire_hour"`
}

// DefaultConfig holds the settings read from the "default" section.
type DefaultConfig struct {
	AdminEmail      string  `mapstructure:"admin_email"`
	AdminPassword   string  `mapstructure:"admin_password"`
	UserConcurrency int     `mapstructure:"user_concurrency"`
	UserBalance     float64 `mapstructure:"user_balance"`
	ApiKeyPrefix    string  `mapstructure:"api_key_prefix"`
	RateMultiplier  float64 `mapstructure:"rate_multiplier"`
}

// RateLimitConfig holds the rate limiting settings.
type RateLimitConfig struct {
	// OverloadCooldownMinutes is the cooldown after a 529 overload, in minutes.
	OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"`
}

// Load reads "config.yaml" from ".", "./config" or "/etc/sub2api", overlays
// environment variables (dots in keys become underscores) on top of the
// built-in defaults, and returns the validated configuration.
//
// A missing config file is not an error: defaults and environment variables
// are used instead. Any other read error, an unmarshal error or a validation
// error is returned wrapped.
func Load() (*Config, error) {
	viper.SetConfigName("config")
	viper.SetConfigType("yaml")
	viper.AddConfigPath(".")
	viper.AddConfigPath("./config")
	viper.AddConfigPath("/etc/sub2api")

	// Environment variable support.
	viper.AutomaticEnv()
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))

	// Defaults.
	setDefaults()

	if err := viper.ReadInConfig(); err != nil {
		// errors.As also recognises the not-found error when it is wrapped.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("read config error: %w", err)
		}
		// The config file does not exist: fall back to the defaults.
	}

	var cfg Config
	if err := viper.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config error: %w", err)
	}
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("validate config error: %w", err)
	}
	return &cfg, nil
}

// setDefaults registers the built-in default for every configuration key.
func setDefaults() {
	// Server
	viper.SetDefault("server.host", "0.0.0.0")
	viper.SetDefault("server.port", 8080)
	viper.SetDefault("server.mode", "debug")
	viper.SetDefault("server.read_header_timeout", 30) // 30 seconds to read request headers
	viper.SetDefault("server.idle_timeout", 120)       // 120 seconds idle timeout

	// Database
	viper.SetDefault("database.host", "localhost")
	viper.SetDefault("database.port", 5432)
	viper.SetDefault("database.user", "postgres")
	viper.SetDefault("database.password", "postgres")
	viper.SetDefault("database.dbname", "sub2api")
	viper.SetDefault("database.sslmode", "disable")

	// Redis
	viper.SetDefault("redis.host", "localhost")
	viper.SetDefault("redis.port", 6379)
	viper.SetDefault("redis.password", "")
	viper.SetDefault("redis.db", 0)

	// JWT
	viper.SetDefault("jwt.secret", placeholderJWTSecret)
	viper.SetDefault("jwt.expire_hour", 24)

	// Default
	viper.SetDefault("default.admin_email", "admin@sub2api.com")
	viper.SetDefault("default.admin_password", "admin123")
	viper.SetDefault("default.user_concurrency", 5)
	viper.SetDefault("default.user_balance", 0)
	viper.SetDefault("default.api_key_prefix", "sk-")
	viper.SetDefault("default.rate_multiplier", 1.0)

	// RateLimit
	viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)

	// Pricing - synced from the price-mirror branch, which maintains a sha256
	// hash file used for incremental update checks.
	viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
	viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
	viper.SetDefault("pricing.data_dir", "./data")
	viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
	viper.SetDefault("pricing.update_interval_hours", 24)
	viper.SetDefault("pricing.hash_check_interval_minutes", 10)

	// Timezone (default to Asia/Shanghai for Chinese users)
	viper.SetDefault("timezone", "Asia/Shanghai")

	// Gateway: wait 300 seconds (5 minutes) for the upstream response headers;
	// the LLM may queue for a long time under high load.
	viper.SetDefault("gateway.response_header_timeout", 300)
}

// Validate reports an error when the JWT secret is empty, or when it is still
// the built-in placeholder while the server mode is "release".
func (c *Config) Validate() error {
	switch {
	case c.JWT.Secret == "":
		return errors.New("jwt.secret is required")
	case c.JWT.Secret == placeholderJWTSecret && c.Server.Mode == "release":
		return errors.New("jwt.secret must be changed in production")
	}
	return nil
}

// joinHostPort formats host and port as a network address. A host that is
// already bracketed by hand (e.g. "[::1]") is used as-is so it is not
// bracketed twice; otherwise net.JoinHostPort adds brackets when needed.
func joinHostPort(host string, port int) string {
	portStr := strconv.Itoa(port)
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host + ":" + portStr
	}
	return net.JoinHostPort(host, portStr)
}

// quoteDSNValue returns v in a form that is safe as a value in a
// keyword/value connection string. Values needing no quoting are returned
// unchanged; empty values and values containing whitespace, a single quote
// or a backslash are single-quoted with quotes and backslashes escaped.
func quoteDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\r\n\v\f'\\") {
		return v
	}
	return "'" + dsnValueEscaper.Replace(v) + "'"
}