feat(安全): 添加安全开关并完善测试流程
实现安全开关默认关闭与响应头透传逻辑 - URL 校验与响应头过滤支持开关并覆盖流式路径 - 非流式 Content-Type 透传/默认值按配置生效 - 接入 go test、golangci-lint 与前端 lint/typecheck - 补充相关测试与配置/文档说明
This commit is contained in:
@@ -72,6 +72,9 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
|
||||
if s.cfg == nil {
|
||||
return "", errors.New("config is not available")
|
||||
}
|
||||
if !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
|
||||
@@ -194,9 +194,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
if s.cfg == nil {
|
||||
return nil, errors.New("config is not available")
|
||||
}
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
baseURL := strings.TrimSpace(input.BaseURL)
|
||||
if s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseURL = normalized
|
||||
}
|
||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
@@ -204,7 +208,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
|
||||
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1583,6 +1583,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -1837,8 +1841,15 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := "application/json"
|
||||
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
||||
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
||||
contentType = upstreamType
|
||||
}
|
||||
}
|
||||
|
||||
// 写入响应
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return &response.Usage, nil
|
||||
}
|
||||
@@ -2194,6 +2205,9 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
|
||||
}
|
||||
|
||||
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
|
||||
@@ -237,6 +237,9 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
@@ -1720,6 +1723,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
}
|
||||
log.Printf("[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -762,6 +762,10 @@ type openaiStreamingResult struct {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -1030,12 +1034,22 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
contentType := "application/json"
|
||||
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
||||
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
||||
contentType = upstreamType
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -88,3 +89,175 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
||||
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
||||
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{
|
||||
"Cache-Control": []string{"upstream"},
|
||||
"X-Test": []string{"value"},
|
||||
"Content-Type": []string{"application/custom"},
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("handleStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if rec.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
||||
}
|
||||
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
if rec.Header().Get("X-Test") != "value" {
|
||||
t.Fatalf("expected X-Test passthrough, got %q", rec.Header().Get("X-Test"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||
}
|
||||
|
||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLDisabledSkipsValidation(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error when allowlist disabled, got %v", err)
|
||||
}
|
||||
if normalized != "http://not-https.example.com" {
|
||||
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: true,
|
||||
UpstreamHosts: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
||||
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
||||
}
|
||||
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
||||
t.Fatalf("expected non-allowlisted host to fail")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,6 +409,9 @@ func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
}
|
||||
|
||||
func (s *PricingService) validatePricingURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
|
||||
RequireAllowlist: true,
|
||||
|
||||
Reference in New Issue
Block a user