实现 allow_insecure_http 并在关闭校验时执行最小格式验证 - 关闭 allowlist 时要求 URL 可解析且 scheme 合规 - 响应头过滤关闭时使用默认白名单策略 - 更新相关文档、示例与测试覆盖
287 lines
8.3 KiB
Go
287 lines
8.3 KiB
Go
package service
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func TestOpenAIStreamingTimeout(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 1,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: defaultMaxLineSize,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{},
|
|
}
|
|
|
|
start := time.Now()
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
|
_ = pw.Close()
|
|
_ = pr.Close()
|
|
|
|
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
|
t.Fatalf("expected stream timeout error, got %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
|
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestOpenAIStreamingTooLong(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 0,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: 64 * 1024,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{},
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = pw.Close() }()
|
|
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
|
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
|
_, _ = pw.Write([]byte(payload))
|
|
}()
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
|
_ = pr.Close()
|
|
|
|
if !errors.Is(err, bufio.ErrTooLong) {
|
|
t.Fatalf("expected ErrTooLong, got %v", err)
|
|
}
|
|
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
|
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
|
}
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
if err != nil {
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
|
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewReader(body)),
|
|
Header: http.Header{},
|
|
}
|
|
|
|
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
|
if err != nil {
|
|
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
|
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
|
},
|
|
Gateway: config.GatewayConfig{
|
|
StreamDataIntervalTimeout: 0,
|
|
StreamKeepaliveInterval: 0,
|
|
MaxLineSize: defaultMaxLineSize,
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
pr, pw := io.Pipe()
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: pr,
|
|
Header: http.Header{
|
|
"Cache-Control": []string{"upstream"},
|
|
"X-Request-Id": []string{"req-123"},
|
|
"Content-Type": []string{"application/custom"},
|
|
},
|
|
}
|
|
|
|
go func() {
|
|
defer func() { _ = pw.Close() }()
|
|
_, _ = pw.Write([]byte("data: {}\n\n"))
|
|
}()
|
|
|
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
|
_ = pr.Close()
|
|
if err != nil {
|
|
t.Fatalf("handleStreamingResponse error: %v", err)
|
|
}
|
|
|
|
if rec.Header().Get("Cache-Control") != "no-cache" {
|
|
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
|
}
|
|
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
|
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
|
}
|
|
if rec.Header().Get("X-Request-Id") != "req-123" {
|
|
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
|
|
}
|
|
}
|
|
|
|
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
|
|
|
account := &Account{
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
Credentials: map[string]any{"base_url": "://invalid-url"},
|
|
}
|
|
|
|
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
|
if err == nil {
|
|
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
|
}
|
|
}
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
|
|
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
|
|
}
|
|
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
|
|
if err != nil {
|
|
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
|
|
}
|
|
if normalized != "https://example.com" {
|
|
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
|
}
|
|
}
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{
|
|
Enabled: false,
|
|
AllowInsecureHTTP: true,
|
|
},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
|
if err != nil {
|
|
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
|
|
}
|
|
if normalized != "http://not-https.example.com" {
|
|
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
|
}
|
|
}
|
|
|
|
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
|
cfg := &config.Config{
|
|
Security: config.SecurityConfig{
|
|
URLAllowlist: config.URLAllowlistConfig{
|
|
Enabled: true,
|
|
UpstreamHosts: []string{"example.com"},
|
|
},
|
|
},
|
|
}
|
|
svc := &OpenAIGatewayService{cfg: cfg}
|
|
|
|
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
|
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
|
}
|
|
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
|
t.Fatalf("expected non-allowlisted host to fail")
|
|
}
|
|
}
|