实现安全开关默认关闭与响应头透传逻辑 - URL 校验与响应头过滤支持开关并覆盖流式路径 - 非流式 Content-Type 透传/默认值按配置生效 - 接入 go test、golangci-lint 与前端 lint/typecheck - 补充相关测试与配置/文档说明
264 lines
7.6 KiB
Go
264 lines
7.6 KiB
Go
package service
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func TestOpenAIStreamingTimeout(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 1,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: defaultMaxLineSize,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{},
|
|
}
|
|
|
|
start := time.Now()
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
|
_ = pw.Close()
|
|
_ = pr.Close()
|
|
|
|
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
|
t.Fatalf("expected stream timeout error, got %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
|
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestOpenAIStreamingTooLong(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 0,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: 64 * 1024,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{},
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = pw.Close() }()
|
|
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
|
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
|
_, _ = pw.Write([]byte(payload))
|
|
}()
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
|
_ = pr.Close()
|
|
|
|
if !errors.Is(err, bufio.ErrTooLong) {
|
|
t.Fatalf("expected ErrTooLong, got %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
|
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
|
}
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
if err != nil {
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
|
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
Header: http.Header{},
|
|
}
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
if err != nil {
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
|
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 0,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: defaultMaxLineSize,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{
|
|
"Cache-Control": []string{"upstream"},
|
|
"X-Test": []string{"value"},
|
|
"Content-Type": []string{"application/custom"},
|
|
},
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = pw.Close() }()
|
|
_, _ = pw.Write([]byte("data: {}\n\n"))
|
|
}()
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
_ = pr.Close()
|
|
if err != nil {
|
|
t.Fatalf("handleStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if rec.Header().Get("Cache-Control") != "no-cache" {
|
|
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
|
}
|
|
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
|
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
if rec.Header().Get("X-Test") != "value" {
|
|
t.Fatalf("expected X-Test passthrough, got %q", rec.Header().Get("X-Test"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
account := &Account{
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Credentials: map[string]any{"base_url": "://invalid-url"},
|
|
}
|
|
|
|
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
|
if err == nil {
|
|
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
|
}
|
|
}
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLDisabledSkipsValidation(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
|
if err != nil {
|
|
t.Fatalf("expected no error when allowlist disabled, got %v", err)
|
|
}
|
|
if normalized != "http://not-https.example.com" {
|
|
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
|
}
|
|
}
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{
|
|
Enabled: true,
|
|
UpstreamHosts: []string{"example.com"},
|
|
},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
|
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
|
}
|
|
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
|
t.Fatalf("expected non-allowlisted host to fail")
|
|
}
|
|
}
|