fix: pin payment read paths to provider snapshots
This commit is contained in:
@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
|
||||
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
|
||||
instanceProviderKey = inst.ProviderKey
|
||||
}
|
||||
expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
|
||||
expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
|
||||
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
|
||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
|
||||
"expectedProvider": expectedProviderKey,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi
|
||||
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registry := payment.NewRegistry()
|
||||
registry.Register(paymentFulfillmentTestProvider{
|
||||
key: payment.TypeAlipay,
|
||||
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
|
||||
})
|
||||
|
||||
order := &dbent.PaymentOrder{
|
||||
PaymentType: payment.TypeAlipay,
|
||||
ProviderSnapshot: map[string]any{
|
||||
"schema_version": 1,
|
||||
"provider_key": payment.TypeEasyPay,
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t,
|
||||
payment.TypeEasyPay,
|
||||
expectedNotificationProviderKeyForOrder(registry, order, ""),
|
||||
)
|
||||
}
|
||||
|
||||
115
backend/internal/service/payment_order_provider_snapshot.go
Normal file
115
backend/internal/service/payment_order_provider_snapshot.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
type paymentOrderProviderSnapshot struct {
|
||||
SchemaVersion int
|
||||
ProviderInstanceID string
|
||||
ProviderKey string
|
||||
PaymentMode string
|
||||
}
|
||||
|
||||
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
|
||||
if order == nil || len(order.ProviderSnapshot) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshot := &paymentOrderProviderSnapshot{
|
||||
SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
|
||||
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
|
||||
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
|
||||
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
|
||||
}
|
||||
if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
|
||||
return nil
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func psSnapshotStringValue(value any) string {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func psSnapshotIntValue(value any) int {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int32:
|
||||
return int(typed)
|
||||
case int64:
|
||||
return int(typed)
|
||||
case float32:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
case string:
|
||||
n, err := strconv.Atoi(strings.TrimSpace(typed))
|
||||
if err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
|
||||
if s == nil || s.entClient == nil || order == nil || snapshot == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
|
||||
columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
|
||||
if snapshotInstanceID == "" {
|
||||
snapshotInstanceID = columnInstanceID
|
||||
}
|
||||
if snapshotInstanceID == "" {
|
||||
return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
|
||||
}
|
||||
if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
|
||||
return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
|
||||
}
|
||||
|
||||
instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
|
||||
}
|
||||
|
||||
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
|
||||
return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
|
||||
}
|
||||
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
|
||||
if order == nil {
|
||||
return strings.TrimSpace(instanceProviderKey)
|
||||
}
|
||||
|
||||
orderProviderKey := psStringValue(order.ProviderKey)
|
||||
if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
|
||||
orderProviderKey = snapshot.ProviderKey
|
||||
}
|
||||
|
||||
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
|
||||
}
|
||||
@@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
|
||||
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
|
||||
}
|
||||
|
||||
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
|
||||
if instIDStr == "" {
|
||||
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
|
||||
|
||||
@@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
|
||||
}
|
||||
|
||||
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
|
||||
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
|
||||
return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
|
||||
}
|
||||
|
||||
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
inst, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeStripe).
|
||||
SetName("stripe-snapshot").
|
||||
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
|
||||
SetSupportedTypes("stripe").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
order := &dbent.PaymentOrder{
|
||||
ID: 42,
|
||||
PaymentType: payment.TypeStripe,
|
||||
ProviderSnapshot: map[string]any{
|
||||
"schema_version": 1,
|
||||
"provider_instance_id": strconv.FormatInt(inst.ID, 10),
|
||||
"provider_key": payment.TypeStripe,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||
}
|
||||
|
||||
got, err := svc.getOrderProviderInstance(ctx, order)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, inst.ID, got.ID)
|
||||
}
|
||||
|
||||
func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeStripe).
|
||||
SetName("stripe-legacy-fallback").
|
||||
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
|
||||
SetSupportedTypes("stripe").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
order := &dbent.PaymentOrder{
|
||||
ID: 43,
|
||||
PaymentType: payment.TypeStripe,
|
||||
ProviderSnapshot: map[string]any{
|
||||
"schema_version": 1,
|
||||
"provider_instance_id": "999999",
|
||||
"provider_key": payment.TypeStripe,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||
}
|
||||
|
||||
got, err := svc.getOrderProviderInstance(ctx, order)
|
||||
require.Nil(t, got)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
|
||||
}
|
||||
|
||||
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
@@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "provider instance")
|
||||
}
|
||||
|
||||
func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
user, err := client.User.Create().
|
||||
SetEmail("snapshot-webhook@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetUsername("snapshot-webhook").
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
|
||||
"appId": "wx-app-snapshot-a",
|
||||
"mchId": "mch-snapshot-a",
|
||||
"privateKey": "private-key-snapshot-a",
|
||||
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||
"publicKey": "public-key-snapshot-a",
|
||||
"publicKeyId": "public-key-id-snapshot-a",
|
||||
"certSerial": "cert-serial-snapshot-a",
|
||||
})
|
||||
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
|
||||
"appId": "wx-app-snapshot-b",
|
||||
"mchId": "mch-snapshot-b",
|
||||
"privateKey": "private-key-snapshot-b",
|
||||
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||
"publicKey": "public-key-snapshot-b",
|
||||
"publicKeyId": "public-key-id-snapshot-b",
|
||||
"certSerial": "cert-serial-snapshot-b",
|
||||
})
|
||||
instA, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-snapshot-a").
|
||||
SetConfig(wxpayConfigA).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-snapshot-b").
|
||||
SetConfig(wxpayConfigB).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.PaymentOrder.Create().
|
||||
SetUserID(user.ID).
|
||||
SetUserEmail(user.Email).
|
||||
SetUserName(user.Username).
|
||||
SetAmount(66).
|
||||
SetPayAmount(66).
|
||||
SetFeeRate(0).
|
||||
SetRechargeCode("SNAPSHOT-WEBHOOK").
|
||||
SetOutTradeNo("sub2_test_snapshot_webhook_order").
|
||||
SetPaymentType(payment.TypeWxpay).
|
||||
SetPaymentTradeNo("").
|
||||
SetOrderType(payment.OrderTypeBalance).
|
||||
SetStatus(OrderStatusPending).
|
||||
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||
SetClientIP("127.0.0.1").
|
||||
SetSrcHost("api.example.com").
|
||||
SetProviderSnapshot(map[string]any{
|
||||
"schema_version": 1,
|
||||
"provider_instance_id": strconv.FormatInt(instA.ID, 10),
|
||||
"provider_key": payment.TypeWxpay,
|
||||
"payment_mode": "native",
|
||||
}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||
registry: payment.NewRegistry(),
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user