package service

import (
	"bufio"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/gin-gonic/gin"
)

// TestOpenAIStreamingTimeout checks that handleStreamingResponse ends a stream
// whose upstream body goes silent: it must return an error mentioning
// "stream data interval timeout" and emit a stream_timeout SSE error to the
// client.
func TestOpenAIStreamingTimeout(t *testing.T) {
	// Upper bound on how long the test waits for the handler. It is far above
	// the configured interval timeout (presumably 1 second — confirm the unit of
	// StreamDataIntervalTimeout), so it only trips when the timeout is not
	// enforced at all, instead of letting the test hang until `go test -timeout`.
	const handlerDeadline = 30 * time.Second

	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 1,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	ctx := c.Request.Context()

	// Nothing is ever written to pw, so the upstream body never yields data and
	// only the interval timeout can end the stream. The deferred close also
	// releases a still-blocked read if the test gives up on the deadline.
	pr, pw := io.Pipe()
	defer func() {
		_ = pw.Close()
		_ = pr.Close()
	}()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}

	// Buffered so the handler goroutine never blocks on send, even after the
	// test has already failed on the deadline.
	errCh := make(chan error, 1)
	start := time.Now()
	go func() {
		_, err := svc.handleStreamingResponse(ctx, resp, c, &Account{ID: 1}, start, "model", "model")
		errCh <- err
	}()

	var err error
	select {
	case err = <-errCh:
	case <-time.After(handlerDeadline):
		t.Fatalf("handleStreamingResponse did not return within %v; stream data interval timeout not enforced", handlerDeadline)
	}

	if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
		t.Fatalf("expected stream timeout error, got %v", err)
	}
	if !strings.Contains(rec.Body.String(), "stream_timeout") {
		t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
	}
}

// TestOpenAIStreamingTooLong checks that a single upstream SSE line longer than
// MaxLineSize makes handleStreamingResponse return bufio.ErrTooLong and emit a
// response_too_large SSE error to the client.
func TestOpenAIStreamingTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               64 * 1024,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}

	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		defer func() { _ = pw.Close() }()
		// Write a single line longer than MaxLineSize (about twice the 64 KiB
		// limit) to trigger ErrTooLong.
		payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
		// Write may fail with io.ErrClosedPipe once the read end is closed below;
		// that is expected, so the error is deliberately ignored.
		_, _ = pw.Write([]byte(payload))
	}()

	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
	// The handler presumably stops reading after the oversized line, which can
	// leave the writer blocked in Write. Closing the read end unblocks it, and
	// waiting on writerDone guarantees the goroutine has exited before the test
	// returns.
	_ = pr.Close()
	<-writerDone

	if !errors.Is(err, bufio.ErrTooLong) {
		t.Fatalf("expected ErrTooLong, got %v", err)
	}
	if !strings.Contains(rec.Body.String(), "response_too_large") {
		t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
	}
}