feat: 添加流模式下的SSE保活机制 #945

This commit is contained in:
CaIon
2025-04-14 19:40:23 +08:00
parent dcf7878772
commit 2f3acd9d22
8 changed files with 136 additions and 31 deletions

View File

@@ -141,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)

View File

@@ -6,6 +6,7 @@ import (
"one-api/dto"
relayconstant "one-api/relay/constant"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -54,6 +55,7 @@ type RelayInfo struct {
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
responseMutex sync.Mutex // Add mutex for protecting concurrent access
//SendLastReasoningResponse bool
ApiType int
IsStream bool
@@ -212,12 +214,19 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
info.responseMutex.Lock()
defer info.responseMutex.Unlock()
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
info.isFirstResponse = false
}
}
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct {
*RelayInfo
Action string

View File

@@ -55,6 +55,16 @@ func StringData(c *gin.Context, str string) error {
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")

View File

@@ -3,12 +3,15 @@ package helper
import (
"bufio"
"context"
"github.com/bytedance/gopkg/util/gopool"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -17,11 +20,12 @@ import (
const (
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil {
if resp == nil || dataHandler == nil {
return
}
@@ -34,13 +38,29 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
}
if pingEnabled {
pingTicker = time.NewTicker(pingInterval)
}
defer func() {
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -51,6 +71,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
// Handle ping data sending
if pingEnabled && pingTicker != nil {
gopool.Go(func() {
for {
select {
case <-pingTicker.C:
writeMutex.Lock() // Lock before writing
err := PingData(c)
writeMutex.Unlock() // Unlock after writing
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
common.SafeSendBool(stopChan, true)
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-ctx.Done():
if common.DebugEnabled {
println("ping data goroutine stopped")
}
return
}
}
})
}
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
@@ -70,7 +118,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
writeMutex.Lock() // Lock before writing
success := dataHandler(data)
writeMutex.Unlock() // Unlock after writing
if !success {
break
}
@@ -90,7 +140,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
common.SafeSendBool(stopChan, true)
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
}
}

View File

@@ -3,12 +3,16 @@ package operation_setting
import "one-api/setting/config"
type GeneralSetting struct {
DocsLink string `json:"docs_link"`
DocsLink string `json:"docs_link"`
PingIntervalEnabled bool `json:"ping_interval_enabled"`
PingIntervalSeconds int `json:"ping_interval_seconds"`
}
// 默认配置
var generalSetting = GeneralSetting{
DocsLink: "https://docs.newapi.pro",
DocsLink: "https://docs.newapi.pro",
PingIntervalEnabled: false,
PingIntervalSeconds: 60,
}
func init() {

View File

@@ -18,6 +18,8 @@ const ModelSetting = () => {
'claude.default_max_tokens': '',
'claude.thinking_adapter_budget_tokens_percentage': 0.8,
'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
});
let [loading, setLoading] = useState(false);

View File

@@ -793,23 +793,7 @@ const PersonalSetting = () => {
</div>
</Card>
<Card style={{ marginTop: 10 }}>
<Tabs type="line" defaultActiveKey="price">
<TabPane tab={t('价格设置')} itemKey="price">
<div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
<div style={{ marginTop: 10 }}>
<Checkbox
checked={notificationSettings.acceptUnsetModelRatioModel}
onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
>
{t('接受未设置价格模型')}
</Checkbox>
<Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
</Typography.Text>
</div>
</div>
</TabPane>
<Tabs type="line" defaultActiveKey="notification">
<TabPane tab={t('通知设置')} itemKey="notification">
<div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('通知方式')}</Typography.Text>
@@ -923,6 +907,23 @@ const PersonalSetting = () => {
</Typography.Text>
</div>
</TabPane>
<TabPane tab={t('价格设置')} itemKey="price">
<div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
<div style={{ marginTop: 10 }}>
<Checkbox
checked={notificationSettings.acceptUnsetModelRatioModel}
onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
>
{t('接受未设置价格模型')}
</Checkbox>
<Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
</Typography.Text>
</div>
</div>
</TabPane>
</Tabs>
<div style={{ marginTop: 20 }}>
<Button type="primary" onClick={saveNotificationSettings}>

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import { Button, Col, Form, Row, Spin, Banner } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
@@ -15,6 +15,8 @@ export default function SettingGlobalModel(props) {
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
'global.pass_through_request_enabled': false,
'general_setting.ping_interval_enabled': false,
'general_setting.ping_interval_seconds': 60,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -23,12 +25,8 @@ export default function SettingGlobalModel(props) {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
let value = String(inputs[item.key]);
return API.put('/api/option/', {
key: item.key,
value,
@@ -84,6 +82,36 @@ export default function SettingGlobalModel(props) {
/>
</Col>
</Row>
<Form.Section text={t('连接保活设置')}>
<Row style={{ marginTop: 10 }}>
<Col span={24}>
<Banner
type="warning"
description="警告启用保活后如果已经写入保活数据后渠道出错系统无法重试如果必须开启推荐设置尽可能大的Ping间隔"
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch
label={t('启用Ping间隔')}
field={'general_setting.ping_interval_enabled'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_enabled': value })}
extraText={'开启后将定期发送ping数据保持连接活跃'}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('Ping间隔')}
field={'general_setting.ping_interval_seconds'}
onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_seconds': value })}
min={1}
disabled={!inputs['general_setting.ping_interval_enabled']}
/>
</Col>
</Row>
</Form.Section>
<Row>
<Button size='default' onClick={onSubmit}>