fix(流式): 提升SSE稳定性并统一超时配置
- 扩展SSE行长与间隔超时处理,补充keepalive - 写入失败与超长行时发送错误事件,修复并发释放 - 同步默认配置与示例配置,更新Caddy超时/压缩规则 - 新增OpenAI流式超时与超长行测试 测试: go test ./...
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -64,6 +65,7 @@ type AntigravityGatewayService struct {
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -72,12 +74,14 @@ func NewAntigravityGatewayService(
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -674,57 +678,147 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||||
events := make(chan scanEvent, 1)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTimer *time.Timer
|
||||
if streamInterval > 0 {
|
||||
intervalTimer = time.NewTimer(streamInterval)
|
||||
defer intervalTimer.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTimer != nil {
|
||||
intervalCh = intervalTimer.C
|
||||
}
|
||||
resetInterval := func() {
|
||||
if intervalTimer == nil {
|
||||
return
|
||||
}
|
||||
if !intervalTimer.Stop() {
|
||||
select {
|
||||
case <-intervalTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
intervalTimer.Reset(streamInterval)
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, ev.err
|
||||
}
|
||||
|
||||
resetInterval()
|
||||
line := ev.line
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-intervalCh:
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
|
||||
@@ -863,7 +957,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
|
||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||
var firstTokenMs *int
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize)
|
||||
|
||||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||||
@@ -878,13 +974,95 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("stream read error: %w", err)
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
|
||||
events := make(chan scanEvent, 1)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
if len(line) > 0 {
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTimer *time.Timer
|
||||
if streamInterval > 0 {
|
||||
intervalTimer = time.NewTimer(streamInterval)
|
||||
defer intervalTimer.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTimer != nil {
|
||||
intervalCh = intervalTimer.C
|
||||
}
|
||||
resetInterval := func() {
|
||||
if intervalTimer == nil {
|
||||
return
|
||||
}
|
||||
if !intervalTimer.Stop() {
|
||||
select {
|
||||
case <-intervalTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
intervalTimer.Reset(streamInterval)
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
resetInterval()
|
||||
line := ev.line
|
||||
// 处理 SSE 行,转换为 Claude 格式
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
||||
|
||||
@@ -899,23 +1077,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
}
|
||||
sendErrorEvent("write_failed")
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
case <-intervalCh:
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
@@ -1448,53 +1449,143 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// 设置更大的buffer以处理长行
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
|
||||
events := make(chan scanEvent, 1)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
// 仅监控上游数据间隔超时,避免上游挂起占用资源
|
||||
var intervalTimer *time.Timer
|
||||
if streamInterval > 0 {
|
||||
intervalTimer = time.NewTimer(streamInterval)
|
||||
defer intervalTimer.Stop()
|
||||
}
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "event: error" {
|
||||
return nil, errors.New("have error in stream")
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
line := ev.line
|
||||
if intervalTimer != nil {
|
||||
resetTimer(intervalTimer, streamInterval)
|
||||
}
|
||||
if line == "event: error" {
|
||||
return nil, errors.New("have error in stream")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-func() <-chan time.Time {
|
||||
if intervalTimer != nil {
|
||||
return intervalTimer.C
|
||||
}
|
||||
return nil
|
||||
}():
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func resetTimer(timer *time.Timer, interval time.Duration) {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(interval)
|
||||
}
|
||||
|
||||
// replaceModelInSSELine 替换SSE数据行中的model字段
|
||||
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
if !sseDataRe.MatchString(line) {
|
||||
|
||||
@@ -775,47 +775,154 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
|
||||
events := make(chan scanEvent, 1)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
// 仅监控上游数据间隔超时,不被下游 keepalive 影响
|
||||
var intervalTimer *time.Timer
|
||||
if streamInterval > 0 {
|
||||
intervalTimer = time.NewTimer(streamInterval)
|
||||
defer intervalTimer.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTimer != nil {
|
||||
intervalCh = intervalTimer.C
|
||||
}
|
||||
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
// 下游 keepalive 仅用于防止代理空闲断开
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
line := ev.line
|
||||
lastDataAt = time.Now()
|
||||
if intervalTimer != nil {
|
||||
resetTimer(intervalTimer, streamInterval)
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
|
||||
case <-intervalCh:
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
|
||||
90
backend/internal/service/openai_gateway_service_test.go
Normal file
90
backend/internal/service/openai_gateway_service_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 1,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||
t.Fatalf("expected stream timeout error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: 64 * 1024,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
||||
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
||||
_, _ = pw.Write([]byte(payload))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
|
||||
if !errors.Is(err, bufio.ErrTooLong) {
|
||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user