diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 254af5fe..354f3cd1 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -350,7 +350,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } subject := s.buildPaymentSubject(plan, limitAmount, cfg) outTradeNo := order.OutTradeNo - canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL) + canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost) if err != nil { return nil, err } diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 486aaac0..1806f5da 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net" "net/url" "strconv" "strings" @@ -16,6 +17,8 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) +const paymentResultReturnPath = "/payment/result" + const ( PaymentSourceHostedRedirect = "hosted_redirect" PaymentSourceWechatInAppResume = "wechat_in_app_resume" @@ -215,7 +218,7 @@ func visibleMethodSourceSettingKey(method string) string { } } -func CanonicalizeReturnURL(raw string) (string, error) { +func CanonicalizeReturnURL(raw string, srcHost string) (string, error) { raw = strings.TrimSpace(raw) if raw == "" { return "", nil @@ -231,19 +234,29 @@ func CanonicalizeReturnURL(raw string) (string, error) { if parsed.Path == "" { parsed.Path = "/" } + if parsed.Path != paymentResultReturnPath { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page") + } + if !sameOriginHost(parsed.Host, srcHost) { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site") + } return parsed.String(), nil } func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) { - canonical, err := CanonicalizeReturnURL(base) - if err != nil || canonical == "" { - return canonical, err + canonical := strings.TrimSpace(base) + if canonical == "" { + return "", nil } parsed, err := url.Parse(canonical) if err != nil { return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL") } + if !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL") + } + parsed.Fragment = "" query := parsed.Query() if orderID > 0 { @@ -258,6 +271,31 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri return parsed.String(), nil } +func sameOriginHost(returnURLHost string, requestHost string) bool { + returnHost := strings.TrimSpace(returnURLHost) + reqHost := strings.TrimSpace(requestHost) + if returnHost == "" || reqHost == "" { + return false + } + if strings.EqualFold(returnHost, reqHost) { + return true + } + + returnName, returnPort := splitHostPortDefault(returnHost) + reqName, reqPort := splitHostPortDefault(reqHost) + if returnName == "" || reqName == "" { + return false + } + return strings.EqualFold(returnName, reqName) && returnPort == reqPort +} + +func splitHostPortDefault(raw string) (string, string) { + if host, port, err := net.SplitHostPort(raw); err == nil { + return host, port + } + return raw, "" +} + func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { if err := s.ensureSigningKey(); err != nil { return "", err diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 12d67be2..7fa8dca1 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -64,23 +64,39 @@ func TestNormalizePaymentSource(t *testing.T) { func TestCanonicalizeReturnURL(t *testing.T) { t.Parallel() - got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a") + got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com") if err != nil { t.Fatalf("CanonicalizeReturnURL returned error: %v", err) } - if got != "https://example.com/pay/result?b=2" { - t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2") + if got != "https://example.com/payment/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2") } } func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { t.Parallel() - if _, err := CanonicalizeReturnURL("/payment/result"); err == nil { + if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil { t.Fatal("CanonicalizeReturnURL should reject relative URLs") } } +func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject external hosts") + } +} + +func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths") + } +} + func TestBuildPaymentReturnURL(t *testing.T) { t.Parallel()