diff --git a/.env.example b/.env.example index d317e1f3..ea246427 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,8 @@ # 调试相关配置 # 启用pprof # ENABLE_PPROF=true +# 启用调试模式 +# DEBUG=true # 数据库相关配置 # 数据库连接字符串 @@ -41,6 +43,14 @@ # 更新任务启用 # UPDATE_TASK=true +# 对话超时设置 +# 所有请求超时时间,单位秒,默认为0,表示不限制 +# RELAY_TIMEOUT=0 +# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值 +# STREAMING_TIMEOUT=120 + +# Gemini 识别图片 最大图片数量 +# GEMINI_VISION_MAX_IMAGE_NUM=16 # 会话密钥 # SESSION_SECRET=random_string @@ -58,8 +68,6 @@ # GET_MEDIA_TOKEN_NOT_STREAM=true # 设置 Dify 渠道是否输出工作流和节点信息到客户端 # DIFY_DEBUG=true -# 设置流式一次回复的超时时间 -# STREAMING_TIMEOUT=120 # 节点类型 diff --git a/common/init.go b/common/init.go index c0caf0a1..dd680db2 100644 --- a/common/init.go +++ b/common/init.go @@ -24,7 +24,7 @@ func printHelp() { fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") } -func LoadEnv() { +func InitCommonEnv() { flag.Parse() if *PrintVersion { diff --git a/docker-compose.yml b/docker-compose.yml index 3d707ed0..57ad0b30 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,7 +16,7 @@ services: - REDIS_CONN_STRING=redis://redis - TZ=Asia/Shanghai - ERROR_LOG_ENABLED=true # 是否启用错误日志记录 - # - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释 + # - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed diff --git a/main.go b/main.go index cf593b57..5e7656e9 100644 --- a/main.go +++ b/main.go @@ -32,12 +32,12 @@ var buildFS embed.FS var indexPage []byte func main() { - err := godotenv.Load(".env") - if err != nil { - common.SysLog("Support for .env file is disabled: " + err.Error()) - } - common.LoadEnv() + err := InitResources() + if err != nil { + common.FatalLog("failed to initialize resources: " + err.Error()) + return + } common.SetupLogger() common.SysLog("New API " + common.Version + " started") @@ -47,19 +47,7 @@ func main() { if common.DebugEnabled { common.SysLog("running in debug mode") } - // Initialize SQL Database - err = model.InitDB() - if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) - } - model.CheckSetup() - - // Initialize SQL Database - err = model.InitLogDB() - if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) - } defer func() { err := model.CloseDB() if err != nil { @@ -67,21 +55,6 @@ func main() { } }() - // Initialize Redis - err = common.InitRedisClient() - if err != nil { - common.FatalLog("failed to initialize Redis: " + err.Error()) - } - - // Initialize model settings - ratio_setting.InitRatioSettings() - // Initialize constants - constant.InitEnv() - // Initialize options - model.InitOptionMap() - - service.InitTokenEncoders() - if common.RedisEnabled { // for compatibility with old versions common.MemoryCacheEnabled = true @@ -186,3 +159,50 @@ func main() { common.FatalLog("failed to start HTTP server: " + err.Error()) } } + +func InitResources() error { + // Initialize resources here if needed + // This is a placeholder function for future resource initialization + err := godotenv.Load(".env") + if err != nil { + common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + } + + // 加载旧的(common)环境变量 + common.InitCommonEnv() + // 加载constants的环境变量 + constant.InitEnv() + + // Initialize model settings + ratio_setting.InitRatioSettings() + + service.InitHttpClient() + + service.InitTokenEncoders() + + // Initialize SQL Database + err = model.InitDB() + if err != nil { + common.FatalLog("failed to initialize database: " + err.Error()) + return err + } + + model.CheckSetup() + + // Initialize options, should after model.InitDB() + model.InitOptionMap() + + // Initialize SQL Database + err = model.InitLogDB() + if err != nil { + return err + } + + // Initialize Redis + err = common.InitRedisClient() + if err != nil { + return err + } + return nil +} diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 011af262..11492fe3 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -271,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := service.GetImpatientHttpClient().Do(req) + res, err := service.GetHttpClient().Do(req) if err != nil { return nil, err } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index c330c791..3a2845b3 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) // Send request - client := service.GetImpatientHttpClient() + client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { common.SysError("failed to send request: " + err.Error()) diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a69877e2..b526b1c0 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -20,8 +20,8 @@ import ( ) const ( - InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) - MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) + InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) + MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) DefaultPingInterval = 10 * time.Second ) @@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon scanner = bufio.NewScanner(resp.Body) ticker = time.NewTicker(streamingTimeout) pingTicker *time.Ticker - writeMutex sync.Mutex // Mutex to protect concurrent writes + writeMutex sync.Mutex // Mutex to protect concurrent writes wg sync.WaitGroup // 用于等待所有 goroutine 退出 ) @@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon pingTicker = time.NewTicker(pingInterval) } + if common.DebugEnabled { + // print timeout and ping interval for debugging + println("relay timeout seconds:", common.RelayTimeout) + println("streaming timeout seconds:", int64(streamingTimeout.Seconds())) + println("ping interval seconds:", int64(pingInterval.Seconds())) + } + // 改进资源清理,确保所有 goroutine 正确退出 defer func() { // 通知所有 goroutine 停止 common.SafeSendBool(stopChan, true) - + ticker.Stop() if pingTicker != nil { pingTicker.Stop() } - + // 等待所有 goroutine 退出,最多等待5秒 done := make(chan struct{}) go func() { wg.Wait() close(done) }() - + select { case <-done: case <-time.After(5 * time.Second): common.LogError(c, "timeout waiting for goroutines to exit") } - + close(stopChan) }() - + scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) scanner.Split(bufio.ScanLines) SetEventStreamHeaders(c) @@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon println("ping goroutine exited") } }() - + // 添加超时保护,防止 goroutine 无限运行 maxPingDuration := 30 * time.Minute // 最大 ping 持续时间 pingTimeout := time.NewTimer(maxPingDuration) defer pingTimeout.Stop() - + for { select { case <-pingTicker.C: @@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer writeMutex.Unlock() done <- PingData(c) }() - + select { case err := <-done: if err != nil { @@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon println("scanner goroutine exited") } }() - + for scanner.Scan() { // 检查是否需要停止 select { @@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return default: } - + ticker.Reset(streamingTimeout) data := scanner.Text() if common.DebugEnabled { @@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon data = strings.TrimSuffix(data, "\r") if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() - + // 使用超时机制防止写操作阻塞 done := make(chan bool, 1) go func() { @@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer writeMutex.Unlock() done <- dataHandler(data) }() - + select { case success := <-done: if !success { diff --git a/service/http_client.go b/service/http_client.go index 64a361cf..b191ddd7 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -13,9 +13,8 @@ import ( ) var httpClient *http.Client -var impatientHTTPClient *http.Client -func init() { +func InitHttpClient() { if common.RelayTimeout == 0 { httpClient = &http.Client{} } else { @@ -23,20 +22,12 @@ func init() { Timeout: time.Duration(common.RelayTimeout) * time.Second, } } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } } func GetHttpClient() *http.Client { return httpClient } -func GetImpatientHttpClient() *http.Client { - return impatientHTTPClient -} - // NewProxyHttpClient 创建支持代理的 HTTP 客户端 func NewProxyHttpClient(proxyURL string) (*http.Client, error) { if proxyURL == "" { diff --git a/service/webhook.go b/service/webhook.go index ad2967eb..8faccda3 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error } // 发送请求 - client := GetImpatientHttpClient() + client := GetHttpClient() resp, err = client.Do(req) if err != nil { return fmt.Errorf("failed to send webhook request: %v", err)