fix: validate wxpay payments against order snapshots

This commit is contained in:
IanShaw027
2026-04-21 12:57:35 +08:00
parent 35aeeaa6e1
commit 119f784d19
9 changed files with 239 additions and 22 deletions

View File

@@ -32,6 +32,13 @@ const (
wxpayResultPath = "/payment/result" wxpayResultPath = "/payment/result"
) )
const (
wxpayMetadataAppID = "appid"
wxpayMetadataMerchantID = "mchid"
wxpayMetadataCurrency = "currency"
wxpayMetadataTradeState = "trade_state"
)
// WeChat Pay create-payment modes. // WeChat Pay create-payment modes.
const ( const (
wxpayModeNative = "native" wxpayModeNative = "native"
@@ -355,6 +362,32 @@ func mapWxState(s string) string {
} }
} }
func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
if tx == nil {
return nil
}
metadata := map[string]string{}
if appID := wxSV(tx.Appid); appID != "" {
metadata[wxpayMetadataAppID] = appID
}
if merchantID := wxSV(tx.Mchid); merchantID != "" {
metadata[wxpayMetadataMerchantID] = merchantID
}
if tradeState := wxSV(tx.TradeState); tradeState != "" {
metadata[wxpayMetadataTradeState] = tradeState
}
if tx.Amount != nil {
if currency := wxSV(tx.Amount.Currency); currency != "" {
metadata[wxpayMetadataCurrency] = currency
}
}
if len(metadata) == 0 {
return nil
}
return metadata
}
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient() c, err := w.ensureClient()
if err != nil { if err != nil {
@@ -379,7 +412,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil { if tx.SuccessTime != nil {
pa = *tx.SuccessTime pa = *tx.SuccessTime
} }
return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil return &payment.QueryOrderResponse{
TradeNo: id,
Status: mapWxState(wxSV(tx.TradeState)),
Amount: amt,
PaidAt: pa,
Metadata: buildWxpayTransactionMetadata(tx),
}, nil
} }
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
@@ -411,7 +450,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
} }
return &payment.PaymentNotification{ return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo), TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
Amount: amt, Status: st, RawData: rawBody, Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil }, nil
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core" "github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
@@ -102,6 +103,33 @@ func TestWxSV(t *testing.T) {
} }
} }
func TestBuildWxpayTransactionMetadata(t *testing.T) {
t.Parallel()
tx := &payments.Transaction{
Appid: strPtr("wx-app-id"),
Mchid: strPtr("mch-id"),
TradeState: strPtr(wxpayTradeStateSuccess),
Amount: &payments.Amount{
Currency: strPtr(wxpayCurrency),
},
}
metadata := buildWxpayTransactionMetadata(tx)
if metadata[wxpayMetadataAppID] != "wx-app-id" {
t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
}
if metadata[wxpayMetadataMerchantID] != "mch-id" {
t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
}
if metadata[wxpayMetadataCurrency] != wxpayCurrency {
t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
}
if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
}
}
func strPtr(s string) *string { func strPtr(s string) *string {
return &s return &s
} }

View File

@@ -149,19 +149,21 @@ type CreatePaymentResponse struct {
// QueryOrderResponse describes the payment status from the upstream provider. // QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct { type QueryOrderResponse struct {
TradeNo string TradeNo string
Status string // "pending", "paid", "failed", "refunded" Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY Amount float64 // Amount in CNY
PaidAt string // RFC3339 timestamp or empty PaidAt string // RFC3339 timestamp or empty
Metadata map[string]string
} }
// PaymentNotification is the parsed result of a webhook/notify callback. // PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct { type PaymentNotification struct {
TradeNo string TradeNo string
OrderID string OrderID string
Amount float64 Amount float64
Status string // "success" or "failed" Status string // "success" or "failed"
RawData string // Raw notification body for audit RawData string // Raw notification body for audit
Metadata map[string]string
} }
// RefundRequest contains the parameters for requesting a refund. // RefundRequest contains the parameters for requesting a refund.

View File

@@ -28,14 +28,14 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Fallback: try legacy format (sub2_N where N is DB ID) // Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
} }
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
} }
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
} }
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid) o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil { if err != nil {
slog.Error("order not found", "orderID", oid) slog.Error("order not found", "orderID", oid)
@@ -54,6 +54,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
}) })
return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
} }
if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
"detail": err.Error(),
"tradeNo": tradeNo,
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data). // Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
@@ -69,6 +76,50 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk) return s.toPaid(ctx, o, tradeNo, paid, pk)
} }
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
return nil
}
snapshot := psOrderProviderSnapshot(order)
if snapshot == nil {
return nil
}
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["appid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing appid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["mchid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing mchid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("wxpay notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
}
}
if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
}
return nil
}
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" { if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key return key

View File

@@ -264,3 +264,46 @@ func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testi
expectedNotificationProviderKeyForOrder(registry, order, ""), expectedNotificationProviderKeyForOrder(registry, order, ""),
) )
} }
func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"merchant_app_id": "wx-app-expected",
"merchant_id": "mch-expected",
"currency": "CNY",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-other",
"mchid": "mch-expected",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.ErrorContains(t, err, "wxpay appid mismatch")
}
func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": "9",
"provider_key": payment.TypeWxpay,
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-runtime",
"mchid": "mch-runtime",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.NoError(t, err)
}

View File

@@ -139,7 +139,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin tm = defaultOrderTimeoutMin
} }
exp := time.Now().Add(time.Duration(tm) * time.Minute) exp := time.Now().Add(time.Duration(tm) * time.Minute)
providerSnapshot := buildPaymentOrderProviderSnapshot(sel) providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := "" selectedInstanceID := ""
selectedProviderKey := "" selectedProviderKey := ""
if sel != nil { if sel != nil {
@@ -208,13 +208,13 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil return nil
} }
func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any { func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
if sel == nil { if sel == nil {
return nil return nil
} }
snapshot := map[string]any{} snapshot := map[string]any{}
snapshot["schema_version"] = 1 snapshot["schema_version"] = 2
instanceID := strings.TrimSpace(sel.InstanceID) instanceID := strings.TrimSpace(sel.InstanceID)
if instanceID != "" { if instanceID != "" {
@@ -231,12 +231,32 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[strin
snapshot["payment_mode"] = paymentMode snapshot["payment_mode"] = paymentMode
} }
if providerKey == payment.TypeWxpay {
if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
snapshot["merchant_app_id"] = merchantAppID
}
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
snapshot["merchant_id"] = merchantID
}
snapshot["currency"] = "CNY"
}
if len(snapshot) == 1 { if len(snapshot) == 1 {
return nil return nil
} }
return snapshot return snapshot
} }
func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
return ""
}
if strings.TrimSpace(req.OpenID) != "" {
return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
}
return strings.TrimSpace(sel.Config["appId"])
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 { if limit <= 0 {
return nil return nil

View File

@@ -163,7 +163,7 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
} }
notificationTradeNo = upstreamTradeNo notificationTradeNo = upstreamTradeNo
} }
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried // Still return already_paid — order was paid, fulfillment can be retried
} }

View File

@@ -15,6 +15,9 @@ type paymentOrderProviderSnapshot struct {
ProviderInstanceID string ProviderInstanceID string
ProviderKey string ProviderKey string
PaymentMode string PaymentMode string
MerchantAppID string
MerchantID string
Currency string
} }
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot { func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
@@ -27,8 +30,17 @@ func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSna
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]), ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]), ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]), PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
} }
if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" { if snapshot.SchemaVersion == 0 &&
snapshot.ProviderInstanceID == "" &&
snapshot.ProviderKey == "" &&
snapshot.PaymentMode == "" &&
snapshot.MerchantAppID == "" &&
snapshot.MerchantID == "" &&
snapshot.Currency == "" {
return nil return nil
} }
return snapshot return snapshot

View File

@@ -26,18 +26,21 @@ func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T)
}, },
} }
snapshot := buildPaymentOrderProviderSnapshot(sel) snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{ require.Equal(t, map[string]any{
"schema_version": 1, "schema_version": 2,
"provider_instance_id": "12", "provider_instance_id": "12",
"provider_key": payment.TypeWxpay, "provider_key": payment.TypeWxpay,
"payment_mode": "popup", "payment_mode": "popup",
"merchant_app_id": "wx-app-id",
"currency": "CNY",
}, snapshot) }, snapshot)
require.NotContains(t, snapshot, "config") require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey") require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key") require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types") require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name") require.NotContains(t, snapshot, "instance_name")
require.NotContains(t, snapshot, "merchant_id")
} }
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
@@ -98,7 +101,7 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID)) require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey)) require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"]) require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"]) require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"]) require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"]) require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
@@ -108,6 +111,25 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NotContains(t, order.ProviderSnapshot, "instance_name") require.NotContains(t, order.ProviderSnapshot, "instance_name")
} }
func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "88",
ProviderKey: payment.TypeWxpay,
Config: map[string]string{
"appId": "wx-open-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-88",
},
PaymentMode: "jsapi",
}, CreateOrderRequest{OpenID: "openid-123"})
require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
require.Equal(t, "mch-88", snapshot["merchant_id"])
require.Equal(t, "CNY", snapshot["currency"])
}
func valueOrEmpty(v *string) string { func valueOrEmpty(v *string) string {
if v == nil { if v == nil {
return "" return ""