first commit
This commit is contained in:
100
backend/internal/util/logredact/redact.go
Normal file
100
backend/internal/util/logredact/redact.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package logredact
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// maxRedactDepth 限制递归深度以防止栈溢出
|
||||
const maxRedactDepth = 32
|
||||
|
||||
var defaultSensitiveKeys = map[string]struct{}{
|
||||
"authorization_code": {},
|
||||
"code": {},
|
||||
"code_verifier": {},
|
||||
"access_token": {},
|
||||
"refresh_token": {},
|
||||
"id_token": {},
|
||||
"client_secret": {},
|
||||
"password": {},
|
||||
}
|
||||
|
||||
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
|
||||
if input == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
|
||||
if !ok {
|
||||
return map[string]any{}
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
func RedactJSON(raw []byte, extraKeys ...string) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return "<non-json payload redacted>"
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted := redactValueWithDepth(value, keys, 0)
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
func buildKeySet(extraKeys []string) map[string]struct{} {
|
||||
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
|
||||
for k := range defaultSensitiveKeys {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for _, key := range extraKeys {
|
||||
normalized := normalizeKey(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
keys[normalized] = struct{}{}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
|
||||
if depth > maxRedactDepth {
|
||||
return "<depth limit exceeded>"
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(v))
|
||||
for k, val := range v {
|
||||
if isSensitiveKey(k, keys) {
|
||||
out[k] = "***"
|
||||
continue
|
||||
}
|
||||
out[k] = redactValueWithDepth(val, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
out[i] = redactValueWithDepth(item, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveKey(key string, keys map[string]struct{}) bool {
|
||||
_, ok := keys[normalizeKey(key)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func normalizeKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
99
backend/internal/util/responseheaders/responseheaders.go
Normal file
99
backend/internal/util/responseheaders/responseheaders.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package responseheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// defaultAllowed 定义允许透传的响应头白名单
|
||||
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
|
||||
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
|
||||
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
|
||||
// - connection: 由 HTTP 库管理连接复用
|
||||
var defaultAllowed = map[string]struct{}{
|
||||
"content-type": {},
|
||||
"content-encoding": {},
|
||||
"content-language": {},
|
||||
"cache-control": {},
|
||||
"etag": {},
|
||||
"last-modified": {},
|
||||
"expires": {},
|
||||
"vary": {},
|
||||
"date": {},
|
||||
"x-request-id": {},
|
||||
"x-ratelimit-limit-requests": {},
|
||||
"x-ratelimit-limit-tokens": {},
|
||||
"x-ratelimit-remaining-requests": {},
|
||||
"x-ratelimit-remaining-tokens": {},
|
||||
"x-ratelimit-reset-requests": {},
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
"retry-after": {},
|
||||
"location": {},
|
||||
"www-authenticate": {},
|
||||
}
|
||||
|
||||
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"content-length": {},
|
||||
"transfer-encoding": {},
|
||||
"connection": {},
|
||||
}
|
||||
|
||||
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
|
||||
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
|
||||
for key := range defaultAllowed {
|
||||
allowed[key] = struct{}{}
|
||||
}
|
||||
// 关闭时只使用默认白名单,additional/force_remove 不生效
|
||||
if cfg.Enabled {
|
||||
for _, key := range cfg.AdditionalAllowed {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
allowed[normalized] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
forceRemove := map[string]struct{}{}
|
||||
if cfg.Enabled {
|
||||
forceRemove = make(map[string]struct{}, len(cfg.ForceRemove))
|
||||
for _, key := range cfg.ForceRemove {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
forceRemove[normalized] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
filtered := make(http.Header, len(src))
|
||||
for key, values := range src {
|
||||
lower := strings.ToLower(key)
|
||||
if _, blocked := forceRemove[lower]; blocked {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[lower]; !ok {
|
||||
continue
|
||||
}
|
||||
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
|
||||
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
filtered.Add(key, value)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
|
||||
filtered := FilterHeaders(src, cfg)
|
||||
for key, values := range filtered {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package responseheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Request-Id", "req-123")
|
||||
src.Add("X-Test", "ok")
|
||||
src.Add("Connection", "keep-alive")
|
||||
src.Add("Content-Length", "123")
|
||||
|
||||
cfg := config.ResponseHeaderConfig{
|
||||
Enabled: false,
|
||||
ForceRemove: []string{"x-request-id"},
|
||||
}
|
||||
|
||||
filtered := FilterHeaders(src, cfg)
|
||||
if filtered.Get("Content-Type") != "application/json" {
|
||||
t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
|
||||
}
|
||||
if filtered.Get("X-Request-Id") != "req-123" {
|
||||
t.Fatalf("expected X-Request-Id allowed, got %q", filtered.Get("X-Request-Id"))
|
||||
}
|
||||
if filtered.Get("X-Test") != "" {
|
||||
t.Fatalf("expected X-Test removed, got %q", filtered.Get("X-Test"))
|
||||
}
|
||||
if filtered.Get("Connection") != "" {
|
||||
t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection"))
|
||||
}
|
||||
if filtered.Get("Content-Length") != "" {
|
||||
t.Fatalf("expected Content-Length to be removed, got %q", filtered.Get("Content-Length"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Extra", "ok")
|
||||
src.Add("X-Remove", "nope")
|
||||
src.Add("X-Blocked", "nope")
|
||||
|
||||
cfg := config.ResponseHeaderConfig{
|
||||
Enabled: true,
|
||||
AdditionalAllowed: []string{"x-extra"},
|
||||
ForceRemove: []string{"x-remove"},
|
||||
}
|
||||
|
||||
filtered := FilterHeaders(src, cfg)
|
||||
if filtered.Get("Content-Type") != "application/json" {
|
||||
t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
|
||||
}
|
||||
if filtered.Get("X-Extra") != "ok" {
|
||||
t.Fatalf("expected X-Extra allowed, got %q", filtered.Get("X-Extra"))
|
||||
}
|
||||
if filtered.Get("X-Remove") != "" {
|
||||
t.Fatalf("expected X-Remove removed, got %q", filtered.Get("X-Remove"))
|
||||
}
|
||||
if filtered.Get("X-Blocked") != "" {
|
||||
t.Fatalf("expected X-Blocked removed, got %q", filtered.Get("X-Blocked"))
|
||||
}
|
||||
}
|
||||
154
backend/internal/util/urlvalidator/validator.go
Normal file
154
backend/internal/util/urlvalidator/validator.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package urlvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ValidationOptions struct {
|
||||
AllowedHosts []string
|
||||
RequireAllowlist bool
|
||||
AllowPrivate bool
|
||||
}
|
||||
|
||||
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
|
||||
// 最小格式校验:仅保证 URL 可解析且 scheme 合规,不做白名单/私网/SSRF 校验
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("url is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid url: %s", trimmed)
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
|
||||
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
|
||||
if port := parsed.Port(); port != "" {
|
||||
num, err := strconv.Atoi(port)
|
||||
if err != nil || num <= 0 || num > 65535 {
|
||||
return "", fmt.Errorf("invalid port: %s", port)
|
||||
}
|
||||
}
|
||||
|
||||
return trimmed, nil
|
||||
}
|
||||
|
||||
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("url is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid url: %s", trimmed)
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
if !opts.AllowPrivate && isBlockedHost(host) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
allowlist := normalizeAllowlist(opts.AllowedHosts)
|
||||
if opts.RequireAllowlist && len(allowlist) == 0 {
|
||||
return "", errors.New("allowlist is not configured")
|
||||
}
|
||||
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||
parsed.RawPath = ""
|
||||
return strings.TrimRight(parsed.String(), "/"), nil
|
||||
}
|
||||
|
||||
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
|
||||
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
|
||||
func ValidateResolvedIP(host string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns resolution failed: %w", err)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
|
||||
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAllowlist(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
entry := strings.ToLower(strings.TrimSpace(v))
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(entry); err == nil {
|
||||
entry = host
|
||||
}
|
||||
normalized = append(normalized, entry)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isAllowedHost(host string, allowlist []string) bool {
|
||||
for _, entry := range allowlist {
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(entry, "*.") {
|
||||
suffix := strings.TrimPrefix(entry, "*.")
|
||||
if host == suffix || strings.HasSuffix(host, "."+suffix) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if host == entry {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isBlockedHost(host string) bool {
|
||||
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
|
||||
return true
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
24
backend/internal/util/urlvalidator/validator_test.go
Normal file
24
backend/internal/util/urlvalidator/validator_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package urlvalidator
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateURLFormat(t *testing.T) {
|
||||
if _, err := ValidateURLFormat("", false); err == nil {
|
||||
t.Fatalf("expected empty url to fail")
|
||||
}
|
||||
if _, err := ValidateURLFormat("://bad", false); err == nil {
|
||||
t.Fatalf("expected invalid url to fail")
|
||||
}
|
||||
if _, err := ValidateURLFormat("http://example.com", false); err == nil {
|
||||
t.Fatalf("expected http to fail when allow_insecure_http is false")
|
||||
}
|
||||
if _, err := ValidateURLFormat("https://example.com", false); err != nil {
|
||||
t.Fatalf("expected https to pass, got %v", err)
|
||||
}
|
||||
if _, err := ValidateURLFormat("http://example.com", true); err != nil {
|
||||
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
|
||||
}
|
||||
if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
|
||||
t.Fatalf("expected invalid port to fail")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user