feat: enhance environment configuration and resource initialization

This commit is contained in:
CaIon
2025-07-01 13:13:30 +08:00
parent 6b9237f868
commit eb265a55e1
9 changed files with 90 additions and 64 deletions

View File

@@ -7,6 +7,8 @@
# 调试相关配置 # 调试相关配置
# 启用pprof # 启用pprof
# ENABLE_PPROF=true # ENABLE_PPROF=true
# 启用调试模式
# DEBUG=true
# 数据库相关配置 # 数据库相关配置
# 数据库连接字符串 # 数据库连接字符串
@@ -41,6 +43,14 @@
# 更新任务启用 # 更新任务启用
# UPDATE_TASK=true # UPDATE_TASK=true
# 对话超时设置
# 所有请求超时时间单位秒默认为0表示不限制
# RELAY_TIMEOUT=0
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=120
# Gemini 识别图片 最大图片数量
# GEMINI_VISION_MAX_IMAGE_NUM=16
# 会话密钥 # 会话密钥
# SESSION_SECRET=random_string # SESSION_SECRET=random_string
@@ -58,8 +68,6 @@
# GET_MEDIA_TOKEN_NOT_STREAM=true # GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端 # 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true # DIFY_DEBUG=true
# 设置流式一次回复的超时时间
# STREAMING_TIMEOUT=120
# 节点类型 # 节点类型

View File

@@ -24,7 +24,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]") fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
} }
func LoadEnv() { func InitCommonEnv() {
flag.Parse() flag.Parse()
if *PrintVersion { if *PrintVersion {

View File

@@ -16,7 +16,7 @@ services:
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 - ERROR_LOG_ENABLED=true # 是否启用错误日志记录
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释 # - STREAMING_TIMEOUT=120 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

84
main.go
View File

@@ -32,12 +32,12 @@ var buildFS embed.FS
var indexPage []byte var indexPage []byte
func main() { func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv() err := InitResources()
if err != nil {
common.FatalLog("failed to initialize resources: " + err.Error())
return
}
common.SetupLogger() common.SetupLogger()
common.SysLog("New API " + common.Version + " started") common.SysLog("New API " + common.Version + " started")
@@ -47,19 +47,7 @@ func main() {
if common.DebugEnabled { if common.DebugEnabled {
common.SysLog("running in debug mode") common.SysLog("running in debug mode")
} }
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {
@@ -67,21 +55,6 @@ func main() {
} }
}() }()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize model settings
ratio_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
service.InitTokenEncoders()
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions // for compatibility with old versions
common.MemoryCacheEnabled = true common.MemoryCacheEnabled = true
@@ -186,3 +159,50 @@ func main() {
common.FatalLog("failed to start HTTP server: " + err.Error()) common.FatalLog("failed to start HTTP server: " + err.Error())
} }
} }
func InitResources() error {
// Initialize resources here if needed
// This is a placeholder function for future resource initialization
err := godotenv.Load(".env")
if err != nil {
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
}
// 加载旧的(common)环境变量
common.InitCommonEnv()
// 加载constants的环境变量
constant.InitEnv()
// Initialize model settings
ratio_setting.InitRatioSettings()
service.InitHttpClient()
service.InitTokenEncoders()
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
return err
}
model.CheckSetup()
// Initialize options, should after model.InitDB()
model.InitOptionMap()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
return err
}
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
return err
}
return nil
}

View File

@@ -271,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
res, err := service.GetImpatientHttpClient().Do(req) res, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request // Send request
client := service.GetImpatientHttpClient() client := service.GetHttpClient()
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
common.SysError("failed to send request: " + err.Error()) common.SysError("failed to send request: " + err.Error())

View File

@@ -20,8 +20,8 @@ import (
) )
const ( const (
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second DefaultPingInterval = 10 * time.Second
) )
@@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
scanner = bufio.NewScanner(resp.Body) scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout) ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes writeMutex sync.Mutex // Mutex to protect concurrent writes
wg sync.WaitGroup // 用于等待所有 goroutine 退出 wg sync.WaitGroup // 用于等待所有 goroutine 退出
) )
@@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
pingTicker = time.NewTicker(pingInterval) pingTicker = time.NewTicker(pingInterval)
} }
if common.DebugEnabled {
// print timeout and ping interval for debugging
println("relay timeout seconds:", common.RelayTimeout)
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
println("ping interval seconds:", int64(pingInterval.Seconds()))
}
// 改进资源清理,确保所有 goroutine 正确退出 // 改进资源清理,确保所有 goroutine 正确退出
defer func() { defer func() {
// 通知所有 goroutine 停止 // 通知所有 goroutine 停止
common.SafeSendBool(stopChan, true) common.SafeSendBool(stopChan, true)
ticker.Stop() ticker.Stop()
if pingTicker != nil { if pingTicker != nil {
pingTicker.Stop() pingTicker.Stop()
} }
// 等待所有 goroutine 退出最多等待5秒 // 等待所有 goroutine 退出最多等待5秒
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
wg.Wait() wg.Wait()
close(done) close(done)
}() }()
select { select {
case <-done: case <-done:
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
common.LogError(c, "timeout waiting for goroutines to exit") common.LogError(c, "timeout waiting for goroutines to exit")
} }
close(stopChan) close(stopChan)
}() }()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c) SetEventStreamHeaders(c)
@@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
println("ping goroutine exited") println("ping goroutine exited")
} }
}() }()
// 添加超时保护,防止 goroutine 无限运行 // 添加超时保护,防止 goroutine 无限运行
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间 maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
pingTimeout := time.NewTimer(maxPingDuration) pingTimeout := time.NewTimer(maxPingDuration)
defer pingTimeout.Stop() defer pingTimeout.Stop()
for { for {
select { select {
case <-pingTicker.C: case <-pingTicker.C:
@@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer writeMutex.Unlock() defer writeMutex.Unlock()
done <- PingData(c) done <- PingData(c)
}() }()
select { select {
case err := <-done: case err := <-done:
if err != nil { if err != nil {
@@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
println("scanner goroutine exited") println("scanner goroutine exited")
} }
}() }()
for scanner.Scan() { for scanner.Scan() {
// 检查是否需要停止 // 检查是否需要停止
select { select {
@@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return return
default: default:
} }
ticker.Reset(streamingTimeout) ticker.Reset(streamingTimeout)
data := scanner.Text() data := scanner.Text()
if common.DebugEnabled { if common.DebugEnabled {
@@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime() info.SetFirstResponseTime()
// 使用超时机制防止写操作阻塞 // 使用超时机制防止写操作阻塞
done := make(chan bool, 1) done := make(chan bool, 1)
go func() { go func() {
@@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer writeMutex.Unlock() defer writeMutex.Unlock()
done <- dataHandler(data) done <- dataHandler(data)
}() }()
select { select {
case success := <-done: case success := <-done:
if !success { if !success {

View File

@@ -13,9 +13,8 @@ import (
) )
var httpClient *http.Client var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() { func InitHttpClient() {
if common.RelayTimeout == 0 { if common.RelayTimeout == 0 {
httpClient = &http.Client{} httpClient = &http.Client{}
} else { } else {
@@ -23,20 +22,12 @@ func init() {
Timeout: time.Duration(common.RelayTimeout) * time.Second, Timeout: time.Duration(common.RelayTimeout) * time.Second,
} }
} }
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
} }
func GetHttpClient() *http.Client { func GetHttpClient() *http.Client {
return httpClient return httpClient
} }
func GetImpatientHttpClient() *http.Client {
return impatientHTTPClient
}
// NewProxyHttpClient 创建支持代理的 HTTP 客户端 // NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) { func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" { if proxyURL == "" {

View File

@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
} }
// 发送请求 // 发送请求
client := GetImpatientHttpClient() client := GetHttpClient()
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err) return fmt.Errorf("failed to send webhook request: %v", err)