✨ feat: enhance environment configuration and resource initialization
This commit is contained in:
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
|||||||
# 调试相关配置
|
# 调试相关配置
|
||||||
# 启用pprof
|
# 启用pprof
|
||||||
# ENABLE_PPROF=true
|
# ENABLE_PPROF=true
|
||||||
|
# 启用调试模式
|
||||||
|
# DEBUG=true
|
||||||
|
|
||||||
# 数据库相关配置
|
# 数据库相关配置
|
||||||
# 数据库连接字符串
|
# 数据库连接字符串
|
||||||
@@ -41,6 +43,14 @@
|
|||||||
# 更新任务启用
|
# 更新任务启用
|
||||||
# UPDATE_TASK=true
|
# UPDATE_TASK=true
|
||||||
|
|
||||||
|
# 对话超时设置
|
||||||
|
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||||
|
# RELAY_TIMEOUT=0
|
||||||
|
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||||
|
# STREAMING_TIMEOUT=120
|
||||||
|
|
||||||
|
# Gemini 识别图片 最大图片数量
|
||||||
|
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||||
|
|
||||||
# 会话密钥
|
# 会话密钥
|
||||||
# SESSION_SECRET=random_string
|
# SESSION_SECRET=random_string
|
||||||
@@ -58,8 +68,6 @@
|
|||||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||||
# DIFY_DEBUG=true
|
# DIFY_DEBUG=true
|
||||||
# 设置流式一次回复的超时时间
|
|
||||||
# STREAMING_TIMEOUT=120
|
|
||||||
|
|
||||||
|
|
||||||
# 节点类型
|
# 节点类型
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func printHelp() {
|
|||||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadEnv() {
|
func InitCommonEnv() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ services:
|
|||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
|
|||||||
84
main.go
84
main.go
@@ -32,12 +32,12 @@ var buildFS embed.FS
|
|||||||
var indexPage []byte
|
var indexPage []byte
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
err := godotenv.Load(".env")
|
|
||||||
if err != nil {
|
|
||||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
common.LoadEnv()
|
err := InitResources()
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
common.SetupLogger()
|
common.SetupLogger()
|
||||||
common.SysLog("New API " + common.Version + " started")
|
common.SysLog("New API " + common.Version + " started")
|
||||||
@@ -47,19 +47,7 @@ func main() {
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
common.SysLog("running in debug mode")
|
common.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
// Initialize SQL Database
|
|
||||||
err = model.InitDB()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
model.CheckSetup()
|
|
||||||
|
|
||||||
// Initialize SQL Database
|
|
||||||
err = model.InitLogDB()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize database: " + err.Error())
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := model.CloseDB()
|
err := model.CloseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -67,21 +55,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Initialize Redis
|
|
||||||
err = common.InitRedisClient()
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize model settings
|
|
||||||
ratio_setting.InitRatioSettings()
|
|
||||||
// Initialize constants
|
|
||||||
constant.InitEnv()
|
|
||||||
// Initialize options
|
|
||||||
model.InitOptionMap()
|
|
||||||
|
|
||||||
service.InitTokenEncoders()
|
|
||||||
|
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
// for compatibility with old versions
|
// for compatibility with old versions
|
||||||
common.MemoryCacheEnabled = true
|
common.MemoryCacheEnabled = true
|
||||||
@@ -186,3 +159,50 @@ func main() {
|
|||||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InitResources() error {
|
||||||
|
// Initialize resources here if needed
|
||||||
|
// This is a placeholder function for future resource initialization
|
||||||
|
err := godotenv.Load(".env")
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||||
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载旧的(common)环境变量
|
||||||
|
common.InitCommonEnv()
|
||||||
|
// 加载constants的环境变量
|
||||||
|
constant.InitEnv()
|
||||||
|
|
||||||
|
// Initialize model settings
|
||||||
|
ratio_setting.InitRatioSettings()
|
||||||
|
|
||||||
|
service.InitHttpClient()
|
||||||
|
|
||||||
|
service.InitTokenEncoders()
|
||||||
|
|
||||||
|
// Initialize SQL Database
|
||||||
|
err = model.InitDB()
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to initialize database: " + err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
model.CheckSetup()
|
||||||
|
|
||||||
|
// Initialize options, should after model.InitDB()
|
||||||
|
model.InitOptionMap()
|
||||||
|
|
||||||
|
// Initialize SQL Database
|
||||||
|
err = model.InitLogDB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Redis
|
||||||
|
err = common.InitRedisClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Accept", "application/json")
|
req.Header.Add("Accept", "application/json")
|
||||||
res, err := service.GetImpatientHttpClient().Do(req)
|
res, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
client := service.GetImpatientHttpClient()
|
client := service.GetHttpClient()
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to send request: " + err.Error())
|
common.SysError("failed to send request: " + err.Error())
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||||
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||||
DefaultPingInterval = 10 * time.Second
|
DefaultPingInterval = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
scanner = bufio.NewScanner(resp.Body)
|
scanner = bufio.NewScanner(resp.Body)
|
||||||
ticker = time.NewTicker(streamingTimeout)
|
ticker = time.NewTicker(streamingTimeout)
|
||||||
pingTicker *time.Ticker
|
pingTicker *time.Ticker
|
||||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||||
wg sync.WaitGroup // 用于等待所有 goroutine 退出
|
wg sync.WaitGroup // 用于等待所有 goroutine 退出
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
pingTicker = time.NewTicker(pingInterval)
|
pingTicker = time.NewTicker(pingInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if common.DebugEnabled {
|
||||||
|
// print timeout and ping interval for debugging
|
||||||
|
println("relay timeout seconds:", common.RelayTimeout)
|
||||||
|
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
|
||||||
|
println("ping interval seconds:", int64(pingInterval.Seconds()))
|
||||||
|
}
|
||||||
|
|
||||||
// 改进资源清理,确保所有 goroutine 正确退出
|
// 改进资源清理,确保所有 goroutine 正确退出
|
||||||
defer func() {
|
defer func() {
|
||||||
// 通知所有 goroutine 停止
|
// 通知所有 goroutine 停止
|
||||||
common.SafeSendBool(stopChan, true)
|
common.SafeSendBool(stopChan, true)
|
||||||
|
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
if pingTicker != nil {
|
if pingTicker != nil {
|
||||||
pingTicker.Stop()
|
pingTicker.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 等待所有 goroutine 退出,最多等待5秒
|
// 等待所有 goroutine 退出,最多等待5秒
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
common.LogError(c, "timeout waiting for goroutines to exit")
|
common.LogError(c, "timeout waiting for goroutines to exit")
|
||||||
}
|
}
|
||||||
|
|
||||||
close(stopChan)
|
close(stopChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
|
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
SetEventStreamHeaders(c)
|
SetEventStreamHeaders(c)
|
||||||
@@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
println("ping goroutine exited")
|
println("ping goroutine exited")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 添加超时保护,防止 goroutine 无限运行
|
// 添加超时保护,防止 goroutine 无限运行
|
||||||
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
|
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
|
||||||
pingTimeout := time.NewTimer(maxPingDuration)
|
pingTimeout := time.NewTimer(maxPingDuration)
|
||||||
defer pingTimeout.Stop()
|
defer pingTimeout.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-pingTicker.C:
|
case <-pingTicker.C:
|
||||||
@@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
defer writeMutex.Unlock()
|
defer writeMutex.Unlock()
|
||||||
done <- PingData(c)
|
done <- PingData(c)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
println("scanner goroutine exited")
|
println("scanner goroutine exited")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
// 检查是否需要停止
|
// 检查是否需要停止
|
||||||
select {
|
select {
|
||||||
@@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
ticker.Reset(streamingTimeout)
|
ticker.Reset(streamingTimeout)
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
@@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
info.SetFirstResponseTime()
|
info.SetFirstResponseTime()
|
||||||
|
|
||||||
// 使用超时机制防止写操作阻塞
|
// 使用超时机制防止写操作阻塞
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
|||||||
defer writeMutex.Unlock()
|
defer writeMutex.Unlock()
|
||||||
done <- dataHandler(data)
|
done <- dataHandler(data)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case success := <-done:
|
case success := <-done:
|
||||||
if !success {
|
if !success {
|
||||||
|
|||||||
@@ -13,9 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
var impatientHTTPClient *http.Client
|
|
||||||
|
|
||||||
func init() {
|
func InitHttpClient() {
|
||||||
if common.RelayTimeout == 0 {
|
if common.RelayTimeout == 0 {
|
||||||
httpClient = &http.Client{}
|
httpClient = &http.Client{}
|
||||||
} else {
|
} else {
|
||||||
@@ -23,20 +22,12 @@ func init() {
|
|||||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impatientHTTPClient = &http.Client{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHttpClient() *http.Client {
|
func GetHttpClient() *http.Client {
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetImpatientHttpClient() *http.Client {
|
|
||||||
return impatientHTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
||||||
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||||
if proxyURL == "" {
|
if proxyURL == "" {
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
client := GetImpatientHttpClient()
|
client := GetHttpClient()
|
||||||
resp, err = client.Do(req)
|
resp, err = client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send webhook request: %v", err)
|
return fmt.Errorf("failed to send webhook request: %v", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user