diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 7d51dff0..30016338 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -32,6 +32,13 @@ const ( wxpayResultPath = "/payment/result" ) +const ( + wxpayMetadataAppID = "appid" + wxpayMetadataMerchantID = "mchid" + wxpayMetadataCurrency = "currency" + wxpayMetadataTradeState = "trade_state" +) + // WeChat Pay create-payment modes. const ( wxpayModeNative = "native" @@ -355,6 +362,32 @@ func mapWxState(s string) string { } } +func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string { + if tx == nil { + return nil + } + + metadata := map[string]string{} + if appID := wxSV(tx.Appid); appID != "" { + metadata[wxpayMetadataAppID] = appID + } + if merchantID := wxSV(tx.Mchid); merchantID != "" { + metadata[wxpayMetadataMerchantID] = merchantID + } + if tradeState := wxSV(tx.TradeState); tradeState != "" { + metadata[wxpayMetadataTradeState] = tradeState + } + if tx.Amount != nil { + if currency := wxSV(tx.Amount.Currency); currency != "" { + metadata[wxpayMetadataCurrency] = currency + } + } + if len(metadata) == 0 { + return nil + } + return metadata +} + func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { c, err := w.ensureClient() if err != nil { @@ -379,7 +412,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO if tx.SuccessTime != nil { pa = *tx.SuccessTime } - return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil + return &payment.QueryOrderResponse{ + TradeNo: id, + Status: mapWxState(wxSV(tx.TradeState)), + Amount: amt, + PaidAt: pa, + Metadata: buildWxpayTransactionMetadata(tx), + }, nil } func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { @@ -411,7 +450,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers } return &payment.PaymentNotification{ TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo), - Amount: amt, Status: st, RawData: rawBody, + Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx), }, nil } diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index 6d0006be..b3f4f648 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -10,6 +10,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/wechatpay-apiv3/wechatpay-go/core" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" @@ -102,6 +103,33 @@ func TestWxSV(t *testing.T) { } } +func TestBuildWxpayTransactionMetadata(t *testing.T) { + t.Parallel() + + tx := &payments.Transaction{ + Appid: strPtr("wx-app-id"), + Mchid: strPtr("mch-id"), + TradeState: strPtr(wxpayTradeStateSuccess), + Amount: &payments.Amount{ + Currency: strPtr(wxpayCurrency), + }, + } + + metadata := buildWxpayTransactionMetadata(tx) + if metadata[wxpayMetadataAppID] != "wx-app-id" { + t.Fatalf("appid = %q", metadata[wxpayMetadataAppID]) + } + if metadata[wxpayMetadataMerchantID] != "mch-id" { + t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID]) + } + if metadata[wxpayMetadataCurrency] != wxpayCurrency { + t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency]) + } + if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess { + t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState]) + } +} + func strPtr(s string) *string { return &s } diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go index bb125247..29abf82b 100644 --- a/backend/internal/payment/types.go +++ b/backend/internal/payment/types.go @@ -149,19 +149,21 @@ type CreatePaymentResponse struct { // QueryOrderResponse describes the payment status from the upstream provider. type QueryOrderResponse struct { - TradeNo string - Status string // "pending", "paid", "failed", "refunded" - Amount float64 // Amount in CNY - PaidAt string // RFC3339 timestamp or empty + TradeNo string + Status string // "pending", "paid", "failed", "refunded" + Amount float64 // Amount in CNY + PaidAt string // RFC3339 timestamp or empty + Metadata map[string]string } // PaymentNotification is the parsed result of a webhook/notify callback. type PaymentNotification struct { - TradeNo string - OrderID string - Amount float64 - Status string // "success" or "failed" - RawData string // Raw notification body for audit + TradeNo string + OrderID string + Amount float64 + Status string // "success" or "failed" + RawData string // Raw notification body for audit + Metadata map[string]string } // RefundRequest contains the parameters for requesting a refund. diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 9cb03cca..7bde03c8 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -28,14 +28,14 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme // Fallback: try legacy format (sub2_N where N is DB ID) trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata) } return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) } - return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata) } -func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { +func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { slog.Error("order not found", "orderID", oid) @@ -54,6 +54,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo }) return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) } + if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{ + "detail": err.Error(), + "tradeNo": tradeNo, + }) + return err + } // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Also skip if paid is NaN/Inf (malformed provider data). if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { @@ -69,6 +76,50 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } +func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { + if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) { + return nil + } + + snapshot := psOrderProviderSnapshot(order) + if snapshot == nil { + return nil + } + + if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" { + actual := strings.TrimSpace(metadata["appid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing appid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" { + actual := strings.TrimSpace(metadata["mchid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing mchid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.Currency); expected != "" { + actual := strings.ToUpper(strings.TrimSpace(metadata["currency"])) + if actual == "" { + return fmt.Errorf("wxpay notification missing currency") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual) + } + } + if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") { + return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual) + } + + return nil +} + func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { if key := strings.TrimSpace(instanceProviderKey); key != "" { return key diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 3ce82973..8883d3b8 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -264,3 +264,46 @@ func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testi expectedNotificationProviderKeyForOrder(registry, order, ""), ) } + +func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "merchant_app_id": "wx-app-expected", + "merchant_id": "mch-expected", + "currency": "CNY", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-other", + "mchid": "mch-expected", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.ErrorContains(t, err, "wxpay appid mismatch") +} + +func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "9", + "provider_key": payment.TypeWxpay, + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-runtime", + "mchid": "mch-runtime", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.NoError(t, err) +} diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 1f01bc11..6ee490a8 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -139,7 +139,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq tm = defaultOrderTimeoutMin } exp := time.Now().Add(time.Duration(tm) * time.Minute) - providerSnapshot := buildPaymentOrderProviderSnapshot(sel) + providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req) selectedInstanceID := "" selectedProviderKey := "" if sel != nil { @@ -208,13 +208,13 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return nil } -func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any { +func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any { if sel == nil { return nil } snapshot := map[string]any{} - snapshot["schema_version"] = 1 + snapshot["schema_version"] = 2 instanceID := strings.TrimSpace(sel.InstanceID) if instanceID != "" { @@ -231,12 +231,32 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[strin snapshot["payment_mode"] = paymentMode } + if providerKey == payment.TypeWxpay { + if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" { + snapshot["merchant_app_id"] = merchantAppID + } + if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" { + snapshot["merchant_id"] = merchantID + } + snapshot["currency"] = "CNY" + } + if len(snapshot) == 1 { return nil } return snapshot } +func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string { + if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay { + return "" + } + if strings.TrimSpace(req.OpenID) != "" { + return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config)) + } + return strings.TrimSpace(sel.Config["appId"]) +} + func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index 1564c36d..c11baac1 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -163,7 +163,7 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s } notificationTradeNo = upstreamTradeNo } - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { + if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil { slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) // Still return already_paid — order was paid, fulfillment can be retried } diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go index 9a0aa106..31a790c7 100644 --- a/backend/internal/service/payment_order_provider_snapshot.go +++ b/backend/internal/service/payment_order_provider_snapshot.go @@ -15,6 +15,9 @@ type paymentOrderProviderSnapshot struct { ProviderInstanceID string ProviderKey string PaymentMode string + MerchantAppID string + MerchantID string + Currency string } func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot { @@ -27,8 +30,17 @@ func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSna ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]), ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]), PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]), + MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]), + MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]), + Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]), } - if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" { + if snapshot.SchemaVersion == 0 && + snapshot.ProviderInstanceID == "" && + snapshot.ProviderKey == "" && + snapshot.PaymentMode == "" && + snapshot.MerchantAppID == "" && + snapshot.MerchantID == "" && + snapshot.Currency == "" { return nil } return snapshot diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go index c75566bc..bc6666a8 100644 --- a/backend/internal/service/payment_order_provider_snapshot_test.go +++ b/backend/internal/service/payment_order_provider_snapshot_test.go @@ -26,18 +26,21 @@ func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) }, } - snapshot := buildPaymentOrderProviderSnapshot(sel) + snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{}) require.Equal(t, map[string]any{ - "schema_version": 1, + "schema_version": 2, "provider_instance_id": "12", "provider_key": payment.TypeWxpay, "payment_mode": "popup", + "merchant_app_id": "wx-app-id", + "currency": "CNY", }, snapshot) require.NotContains(t, snapshot, "config") require.NotContains(t, snapshot, "privateKey") require.NotContains(t, snapshot, "apiV3Key") require.NotContains(t, snapshot, "supported_types") require.NotContains(t, snapshot, "instance_name") + require.NotContains(t, snapshot, "merchant_id") } func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { @@ -98,7 +101,7 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { require.NoError(t, err) require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID)) require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey)) - require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"]) + require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"]) require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"]) require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"]) require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"]) @@ -108,6 +111,25 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { require.NotContains(t, order.ProviderSnapshot, "instance_name") } +func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "88", + ProviderKey: payment.TypeWxpay, + Config: map[string]string{ + "appId": "wx-open-app", + "mpAppId": "wx-mp-app", + "mchId": "mch-88", + }, + PaymentMode: "jsapi", + }, CreateOrderRequest{OpenID: "openid-123"}) + + require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"]) + require.Equal(t, "mch-88", snapshot["merchant_id"]) + require.Equal(t, "CNY", snapshot["currency"]) +} + func valueOrEmpty(v *string) string { if v == nil { return ""