feat: 添加流模式下的SSE保活机制 #945

This commit is contained in:
CaIon
2025-04-14 19:40:23 +08:00
parent dcf7878772
commit 2f3acd9d22
8 changed files with 136 additions and 31 deletions

View File

@@ -141,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil { if err != nil {
common.SysError("error handling stream format: " + err.Error()) common.SysError("error handling stream format: " + err.Error())
} }
info.SetFirstResponseTime()
} }
lastStreamData = data lastStreamData = data
streamItems = append(streamItems, data) streamItems = append(streamItems, data)

View File

@@ -6,6 +6,7 @@ import (
"one-api/dto" "one-api/dto"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"strings" "strings"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -54,6 +55,7 @@ type RelayInfo struct {
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time FirstResponseTime time.Time
isFirstResponse bool isFirstResponse bool
responseMutex sync.Mutex // Add mutex for protecting concurrent access
//SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
@@ -212,12 +214,19 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
} }
func (info *RelayInfo) SetFirstResponseTime() { func (info *RelayInfo) SetFirstResponseTime() {
info.responseMutex.Lock()
defer info.responseMutex.Unlock()
if info.isFirstResponse { if info.isFirstResponse {
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
info.isFirstResponse = false info.isFirstResponse = false
} }
} }
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct { type TaskRelayInfo struct {
*RelayInfo *RelayInfo
Action string Action string

View File

@@ -55,6 +55,16 @@ func StringData(c *gin.Context, str string) error {
return nil return nil
} }
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error { func ObjectData(c *gin.Context, object interface{}) error {
if object == nil { if object == nil {
return errors.New("object is nil") return errors.New("object is nil")

View File

@@ -3,12 +3,15 @@ package helper
import ( import (
"bufio" "bufio"
"context" "context"
"github.com/bytedance/gopkg/util/gopool"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings" "strings"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -17,11 +20,12 @@ import (
const ( const (
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024) InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
) )
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil { if resp == nil || dataHandler == nil {
return return
} }
@@ -34,13 +38,29 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
} }
var ( var (
stopChan = make(chan bool, 2) stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body) scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout) ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
) )
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
}
if pingEnabled {
pingTicker = time.NewTicker(pingInterval)
}
defer func() { defer func() {
ticker.Stop() ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
close(stopChan) close(stopChan)
}() }()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -51,6 +71,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer cancel() defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan) ctx = context.WithValue(ctx, "stop_chan", stopChan)
// Handle ping data sending
if pingEnabled && pingTicker != nil {
gopool.Go(func() {
for {
select {
case <-pingTicker.C:
writeMutex.Lock() // Lock before writing
err := PingData(c)
writeMutex.Unlock() // Unlock after writing
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
common.SafeSendBool(stopChan, true)
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-ctx.Done():
if common.DebugEnabled {
println("ping data goroutine stopped")
}
return
}
}
})
}
common.RelayCtxGo(ctx, func() { common.RelayCtxGo(ctx, func() {
for scanner.Scan() { for scanner.Scan() {
ticker.Reset(streamingTimeout) ticker.Reset(streamingTimeout)
@@ -70,7 +118,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\"") data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime() info.SetFirstResponseTime()
writeMutex.Lock() // Lock before writing
success := dataHandler(data) success := dataHandler(data)
writeMutex.Unlock() // Unlock after writing
if !success { if !success {
break break
} }
@@ -90,7 +140,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ticker.C: case <-ticker.C:
// 超时处理逻辑 // 超时处理逻辑
common.LogError(c, "streaming timeout") common.LogError(c, "streaming timeout")
common.SafeSendBool(stopChan, true)
case <-stopChan: case <-stopChan:
// 正常结束 // 正常结束
common.LogInfo(c, "streaming finished")
} }
} }

View File

@@ -3,12 +3,16 @@ package operation_setting
import "one-api/setting/config" import "one-api/setting/config"
type GeneralSetting struct { type GeneralSetting struct {
DocsLink string `json:"docs_link"` DocsLink string `json:"docs_link"`
PingIntervalEnabled bool `json:"ping_interval_enabled"`
PingIntervalSeconds int `json:"ping_interval_seconds"`
} }
// 默认配置 // 默认配置
var generalSetting = GeneralSetting{ var generalSetting = GeneralSetting{
DocsLink: "https://docs.newapi.pro", DocsLink: "https://docs.newapi.pro",
PingIntervalEnabled: false,
PingIntervalSeconds: 60,
} }
func init() { func init() {

View File

@@ -18,6 +18,8 @@ const ModelSetting = () => {
'claude.default_max_tokens': '', 'claude.default_max_tokens': '',
'claude.thinking_adapter_budget_tokens_percentage': 0.8, 'claude.thinking_adapter_budget_tokens_percentage': 0.8,
'global.pass_through_request_enabled': false, 'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
}); });
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);

View File

@@ -793,23 +793,7 @@ const PersonalSetting = () => {
</div> </div>
</Card> </Card>
<Card style={{ marginTop: 10 }}> <Card style={{ marginTop: 10 }}>
<Tabs type="line" defaultActiveKey="price"> <Tabs type="line" defaultActiveKey="notification">
<TabPane tab={t('价格设置')} itemKey="price">
<div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
<div style={{ marginTop: 10 }}>
<Checkbox
checked={notificationSettings.acceptUnsetModelRatioModel}
onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
>
{t('接受未设置价格模型')}
</Checkbox>
<Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
</Typography.Text>
</div>
</div>
</TabPane>
<TabPane tab={t('通知设置')} itemKey="notification"> <TabPane tab={t('通知设置')} itemKey="notification">
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('通知方式')}</Typography.Text> <Typography.Text strong>{t('通知方式')}</Typography.Text>
@@ -923,6 +907,23 @@ const PersonalSetting = () => {
</Typography.Text> </Typography.Text>
</div> </div>
</TabPane> </TabPane>
<TabPane tab={t('价格设置')} itemKey="price">
<div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
<div style={{ marginTop: 10 }}>
<Checkbox
checked={notificationSettings.acceptUnsetModelRatioModel}
onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
>
{t('接受未设置价格模型')}
</Checkbox>
<Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
</Typography.Text>
</div>
</div>
</TabPane>
</Tabs> </Tabs>
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Button type="primary" onClick={saveNotificationSettings}> <Button type="primary" onClick={saveNotificationSettings}>

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState, useRef } from 'react'; import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; import { Button, Col, Form, Row, Spin, Banner } from '@douyinfe/semi-ui';
import { import {
compareObjects, compareObjects,
API, API,
@@ -15,6 +15,8 @@ export default function SettingGlobalModel(props) {
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
'global.pass_through_request_enabled': false, 'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
}); });
const refForm = useRef(); const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs); const [inputsRow, setInputsRow] = useState(inputs);
@@ -23,12 +25,8 @@ export default function SettingGlobalModel(props) {
const updateArray = compareObjects(inputs, inputsRow); const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => { const requestQueue = updateArray.map((item) => {
let value = ''; let value = String(inputs[item.key]);
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', { return API.put('/api/option/', {
key: item.key, key: item.key,
value, value,
@@ -84,6 +82,36 @@ export default function SettingGlobalModel(props) {
/> />
</Col> </Col>
</Row> </Row>
<Form.Section text={t('连接保活设置')}>
<Row style={{ marginTop: 10 }}>
<Col span={24}>
<Banner
type="warning"
description="警告启用保活后如果已经写入保活数据后渠道出错系统无法重试如果必须开启推荐设置尽可能大的Ping间隔"
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch
label={t('启用Ping间隔')}
field={'general_setting.ping_interval_enabled'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_enabled': value })}
extraText={'开启后将定期发送ping数据保持连接活跃'}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('Ping间隔')}
field={'general_setting.ping_interval_seconds'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_seconds': value })}
min={1}
disabled={!inputs['general_setting.ping_interval_enabled']}
/>
</Col>
</Row>
</Form.Section>
<Row> <Row>
<Button size='default' onClick={onSubmit}> <Button size='default' onClick={onSubmit}>