From 35aeeaa6e17632ab08bd7f7f85d5edc396e4ce0f Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 12:50:55 +0800 Subject: [PATCH] fix: pin payment read paths to provider snapshots --- .../internal/service/payment_fulfillment.go | 2 +- .../service/payment_fulfillment_test.go | 24 +++ .../payment_order_provider_snapshot.go | 115 ++++++++++++++ backend/internal/service/payment_refund.go | 4 + .../service/payment_webhook_provider.go | 2 +- .../service/payment_webhook_provider_test.go | 150 ++++++++++++++++++ 6 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 backend/internal/service/payment_order_provider_snapshot.go diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 83bac21d..9cb03cca 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { instanceProviderKey = inst.ProviderKey } - expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey) + expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey) if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ "expectedProvider": expectedProviderKey, diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 712129b0..3ce82973 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -7,6 +7,7 @@ import ( "errors" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) @@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""), ) } + +func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_key": payment.TypeEasyPay, + }, + } + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKeyForOrder(registry, order, ""), + ) +} diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go new file mode 100644 index 00000000..9a0aa106 --- /dev/null +++ b/backend/internal/service/payment_order_provider_snapshot.go @@ -0,0 +1,115 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +type paymentOrderProviderSnapshot struct { + SchemaVersion int + ProviderInstanceID string + ProviderKey string + PaymentMode string +} + +func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot { + if order == nil || len(order.ProviderSnapshot) == 0 { + return nil + } + + snapshot := &paymentOrderProviderSnapshot{ + SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]), + ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]), + ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]), + PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]), + } + if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" { + return nil + } + return snapshot +} + +func psSnapshotStringValue(value any) string { + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + default: + return "" + } +} + +func psSnapshotIntValue(value any) int { + switch typed := value.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float32: + return int(typed) + case float64: + return int(typed) + case string: + n, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return n + } + } + return 0 +} + +func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil || order == nil || snapshot == nil { + return nil, nil + } + + snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID) + columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) + if snapshotInstanceID == "" { + snapshotInstanceID = columnInstanceID + } + if snapshotInstanceID == "" { + return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID) + } + if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) { + return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID) + } + + instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64) + if err != nil { + return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID) + } + + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID) + } + return nil, err + } + + if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) { + return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey) + } + + return inst, nil +} + +func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string { + if order == nil { + return strings.TrimSpace(instanceProviderKey) + } + + orderProviderKey := psStringValue(order.ProviderKey) + if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" { + orderProviderKey = snapshot.ProviderKey + } + + return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey) +} diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 57469fa3..fbaeff99 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent. return nil, nil } + if snapshot := psOrderProviderSnapshot(o); snapshot != nil { + return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot) + } + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) if instIDStr == "" { return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go index 82dc9ea3..f2da40d9 100644 --- a/backend/internal/service/payment_webhook_provider.go +++ b/backend/internal/service/payment_webhook_provider.go @@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro } func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { - return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "" + return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "")) } func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) { diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 15b447c2..f12cf691 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -5,6 +5,7 @@ package service import ( "context" "encoding/json" + "strconv" "testing" "time" @@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp require.Nil(t, got) } +func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-snapshot"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 42, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(inst.ID, 10), + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-legacy-fallback"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 43, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "999999", + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.Nil(t, got) + require.Error(t, err) + require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing") +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) @@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "provider instance") } + +func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("snapshot-webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("snapshot-webhook"). + Save(ctx) + require.NoError(t, err) + + wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-snapshot-a", + "mchId": "mch-snapshot-a", + "privateKey": "private-key-snapshot-a", + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": "public-key-snapshot-a", + "publicKeyId": "public-key-id-snapshot-a", + "certSerial": "cert-serial-snapshot-a", + }) + wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-snapshot-b", + "mchId": "mch-snapshot-b", + "privateKey": "private-key-snapshot-b", + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": "public-key-snapshot-b", + "publicKeyId": "public-key-id-snapshot-b", + "certSerial": "cert-serial-snapshot-b", + }) + instA, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-a"). + SetConfig(wxpayConfigA). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-b"). + SetConfig(wxpayConfigB). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(66). + SetPayAmount(66). + SetFeeRate(0). + SetRechargeCode("SNAPSHOT-WEBHOOK"). + SetOutTradeNo("sub2_test_snapshot_webhook_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderSnapshot(map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(instA.ID, 10), + "provider_key": payment.TypeWxpay, + "payment_mode": "native", + }). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order") + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey()) +}