Files
yinghuoapi/backend/internal/config/config.go
shaw e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00

232 lines
7.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
type PricingConfig struct {
// 价格数据远程URL默认使用LiteLLM镜像
RemoteURL string `mapstructure:"remote_url"`
// 哈希校验文件URL
HashURL string `mapstructure:"hash_url"`
// 本地数据目录
DataDir string `mapstructure:"data_dir"`
// 回退文件路径
FallbackFile string `mapstructure:"fallback_file"`
// 更新间隔(小时)
UpdateIntervalHours int `mapstructure:"update_interval_hours"`
// 哈希校验间隔(分钟)
HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"`
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
}
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
}
func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
}
func (d *DatabaseConfig) DSN() string {
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
)
}
// DSNWithTimezone returns DSN with timezone setting
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
if tz == "" {
tz = "Asia/Shanghai"
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
)
}
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
func (r *RedisConfig) Address() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}
type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
func Load() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("./config")
viper.AddConfigPath("/etc/sub2api")
// 环境变量支持
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// 默认值
setDefaults()
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("read config error: %w", err)
}
// 配置文件不存在时使用默认值
}
var cfg Config
if err := viper.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("unmarshal config error: %w", err)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
return &cfg, nil
}
func setDefaults() {
// Server
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
viper.SetDefault("jwt.expire_hour", 24)
// Default
viper.SetDefault("default.admin_email", "admin@sub2api.com")
viper.SetDefault("default.admin_password", "admin123")
viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-")
viper.SetDefault("default.rate_multiplier", 1.0)
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
viper.SetDefault("pricing.hash_check_interval_minutes", 10)
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
}
func (c *Config) Validate() error {
if c.JWT.Secret == "" {
return fmt.Errorf("jwt.secret is required")
}
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
}
return nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}