fix: validate wxpay payments against order snapshots

This commit is contained in:
IanShaw027
2026-04-21 12:57:35 +08:00
parent 35aeeaa6e1
commit 119f784d19
9 changed files with 239 additions and 22 deletions

View File

@@ -28,14 +28,14 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
@@ -54,6 +54,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
}
if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
"detail": err.Error(),
"tradeNo": tradeNo,
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
@@ -69,6 +76,50 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
return nil
}
snapshot := psOrderProviderSnapshot(order)
if snapshot == nil {
return nil
}
if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
actual := strings.TrimSpace(metadata["appid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing appid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["mchid"])
if actual == "" {
return fmt.Errorf("wxpay notification missing mchid")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("wxpay notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
}
}
if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
}
return nil
}
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key

View File

@@ -264,3 +264,46 @@ func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testi
expectedNotificationProviderKeyForOrder(registry, order, ""),
)
}
func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"merchant_app_id": "wx-app-expected",
"merchant_id": "mch-expected",
"currency": "CNY",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-other",
"mchid": "mch-expected",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.ErrorContains(t, err, "wxpay appid mismatch")
}
func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeWxpay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": "9",
"provider_key": payment.TypeWxpay,
},
}
err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
"appid": "wx-app-runtime",
"mchid": "mch-runtime",
"currency": "CNY",
"trade_state": "SUCCESS",
})
assert.NoError(t, err)
}

View File

@@ -139,7 +139,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
providerSnapshot := buildPaymentOrderProviderSnapshot(sel)
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
if sel != nil {
@@ -208,13 +208,13 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any {
func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
if sel == nil {
return nil
}
snapshot := map[string]any{}
snapshot["schema_version"] = 1
snapshot["schema_version"] = 2
instanceID := strings.TrimSpace(sel.InstanceID)
if instanceID != "" {
@@ -231,12 +231,32 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[strin
snapshot["payment_mode"] = paymentMode
}
if providerKey == payment.TypeWxpay {
if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
snapshot["merchant_app_id"] = merchantAppID
}
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
snapshot["merchant_id"] = merchantID
}
snapshot["currency"] = "CNY"
}
if len(snapshot) == 1 {
return nil
}
return snapshot
}
func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
return ""
}
if strings.TrimSpace(req.OpenID) != "" {
return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
}
return strings.TrimSpace(sel.Config["appId"])
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil

View File

@@ -163,7 +163,7 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
}
notificationTradeNo = upstreamTradeNo
}
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}

View File

@@ -15,6 +15,9 @@ type paymentOrderProviderSnapshot struct {
ProviderInstanceID string
ProviderKey string
PaymentMode string
MerchantAppID string
MerchantID string
Currency string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
@@ -27,8 +30,17 @@ func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSna
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
}
if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
if snapshot.SchemaVersion == 0 &&
snapshot.ProviderInstanceID == "" &&
snapshot.ProviderKey == "" &&
snapshot.PaymentMode == "" &&
snapshot.MerchantAppID == "" &&
snapshot.MerchantID == "" &&
snapshot.Currency == "" {
return nil
}
return snapshot

View File

@@ -26,18 +26,21 @@ func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T)
},
}
snapshot := buildPaymentOrderProviderSnapshot(sel)
snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{
"schema_version": 1,
"schema_version": 2,
"provider_instance_id": "12",
"provider_key": payment.TypeWxpay,
"payment_mode": "popup",
"merchant_app_id": "wx-app-id",
"currency": "CNY",
}, snapshot)
require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name")
require.NotContains(t, snapshot, "merchant_id")
}
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
@@ -98,7 +101,7 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"])
require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
@@ -108,6 +111,25 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NotContains(t, order.ProviderSnapshot, "instance_name")
}
func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
t.Parallel()
snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "88",
ProviderKey: payment.TypeWxpay,
Config: map[string]string{
"appId": "wx-open-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-88",
},
PaymentMode: "jsapi",
}, CreateOrderRequest{OpenID: "openid-123"})
require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
require.Equal(t, "mch-88", snapshot["merchant_id"])
require.Equal(t, "CNY", snapshot["currency"])
}
func valueOrEmpty(v *string) string {
if v == nil {
return ""