First commit
This commit is contained in:
205
backend/internal/config/config.go
Normal file
205
backend/internal/config/config.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
// 价格数据远程URL(默认使用LiteLLM镜像)
|
||||
RemoteURL string `mapstructure:"remote_url"`
|
||||
// 哈希校验文件URL
|
||||
HashURL string `mapstructure:"hash_url"`
|
||||
// 本地数据目录
|
||||
DataDir string `mapstructure:"data_dir"`
|
||||
// 回退文件路径
|
||||
FallbackFile string `mapstructure:"fallback_file"`
|
||||
// 更新间隔(小时)
|
||||
UpdateIntervalHours int `mapstructure:"update_interval_hours"`
|
||||
// 哈希校验间隔(分钟)
|
||||
HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
// 注意:这不影响流式数据传输,只控制等待响应头的时间
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
}
|
||||
|
||||
func (s *ServerConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
DBName string `mapstructure:"dbname"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
}
|
||||
|
||||
func (d *DatabaseConfig) DSN() string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
|
||||
)
|
||||
}
|
||||
|
||||
// DSNWithTimezone returns DSN with timezone setting
|
||||
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
|
||||
)
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
func (r *RedisConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
type DefaultConfig struct {
|
||||
AdminEmail string `mapstructure:"admin_email"`
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
UserConcurrency int `mapstructure:"user_concurrency"`
|
||||
UserBalance float64 `mapstructure:"user_balance"`
|
||||
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
RateMultiplier float64 `mapstructure:"rate_multiplier"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("/etc/sub2api")
|
||||
|
||||
// 环境变量支持
|
||||
viper.AutomaticEnv()
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
||||
// 默认值
|
||||
setDefaults()
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("read config error: %w", err)
|
||||
}
|
||||
// 配置文件不存在时使用默认值
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := viper.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal config error: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
// Server
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
viper.SetDefault("database.user", "postgres")
|
||||
viper.SetDefault("database.password", "postgres")
|
||||
viper.SetDefault("database.dbname", "sub2api")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
|
||||
// Redis
|
||||
viper.SetDefault("redis.host", "localhost")
|
||||
viper.SetDefault("redis.port", 6379)
|
||||
viper.SetDefault("redis.password", "")
|
||||
viper.SetDefault("redis.db", 0)
|
||||
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "change-me-in-production")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// Default
|
||||
viper.SetDefault("default.admin_email", "admin@sub2api.com")
|
||||
viper.SetDefault("default.admin_password", "admin123")
|
||||
viper.SetDefault("default.user_concurrency", 5)
|
||||
viper.SetDefault("default.user_balance", 0)
|
||||
viper.SetDefault("default.api_key_prefix", "sk-")
|
||||
viper.SetDefault("default.rate_multiplier", 1.0)
|
||||
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.data_dir", "./data")
|
||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||
viper.SetDefault("pricing.hash_check_interval_minutes", 10)
|
||||
|
||||
// Timezone (default to Asia/Shanghai for Chinese users)
|
||||
viper.SetDefault("timezone", "Asia/Shanghai")
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "" {
|
||||
return fmt.Errorf("jwt.secret is required")
|
||||
}
|
||||
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
|
||||
return fmt.Errorf("jwt.secret must be changed in production")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user