fix: pin payment read paths to provider snapshots

This commit is contained in:
IanShaw027
2026-04-21 12:50:55 +08:00
parent 561405ab00
commit 35aeeaa6e1
6 changed files with 295 additions and 2 deletions

View File

@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
instanceProviderKey = inst.ProviderKey
}
expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
"expectedProvider": expectedProviderKey,

View File

@@ -7,6 +7,7 @@ import (
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
@@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
)
}
func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
t.Parallel()
registry := payment.NewRegistry()
registry.Register(paymentFulfillmentTestProvider{
key: payment.TypeAlipay,
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
})
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAlipay,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_key": payment.TypeEasyPay,
},
}
assert.Equal(t,
payment.TypeEasyPay,
expectedNotificationProviderKeyForOrder(registry, order, ""),
)
}

View File

@@ -0,0 +1,115 @@
package service
import (
"context"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
type paymentOrderProviderSnapshot struct {
SchemaVersion int
ProviderInstanceID string
ProviderKey string
PaymentMode string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
if order == nil || len(order.ProviderSnapshot) == 0 {
return nil
}
snapshot := &paymentOrderProviderSnapshot{
SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
}
if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
return nil
}
return snapshot
}
func psSnapshotStringValue(value any) string {
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
default:
return ""
}
}
func psSnapshotIntValue(value any) int {
switch typed := value.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float32:
return int(typed)
case float64:
return int(typed)
case string:
n, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return n
}
}
return 0
}
func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || order == nil || snapshot == nil {
return nil, nil
}
snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
if snapshotInstanceID == "" {
snapshotInstanceID = columnInstanceID
}
if snapshotInstanceID == "" {
return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
}
if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
}
instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
if err != nil {
return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
}
return nil, err
}
if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
}
return inst, nil
}
func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
if order == nil {
return strings.TrimSpace(instanceProviderKey)
}
orderProviderKey := psStringValue(order.ProviderKey)
if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
orderProviderKey = snapshot.ProviderKey
}
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
}

View File

@@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
return nil, nil
}
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
}
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)

View File

@@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
}
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {

View File

@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"strconv"
"testing"
"time"
@@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
require.Nil(t, got)
}
func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
inst, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-snapshot").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
ID: 42,
PaymentType: payment.TypeStripe,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": strconv.FormatInt(inst.ID, 10),
"provider_key": payment.TypeStripe,
},
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, inst.ID, got.ID)
}
func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("stripe-legacy-fallback").
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
SetSupportedTypes("stripe").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
order := &dbent.PaymentOrder{
ID: 43,
PaymentType: payment.TypeStripe,
ProviderSnapshot: map[string]any{
"schema_version": 1,
"provider_instance_id": "999999",
"provider_key": payment.TypeStripe,
},
}
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
}
got, err := svc.getOrderProviderInstance(ctx, order)
require.Nil(t, got)
require.Error(t, err)
require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
}
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
@@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}
func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
user, err := client.User.Create().
SetEmail("snapshot-webhook@example.com").
SetPasswordHash("hash").
SetUsername("snapshot-webhook").
Save(ctx)
require.NoError(t, err)
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-snapshot-a",
"mchId": "mch-snapshot-a",
"privateKey": "private-key-snapshot-a",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-snapshot-a",
"publicKeyId": "public-key-id-snapshot-a",
"certSerial": "cert-serial-snapshot-a",
})
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-snapshot-b",
"mchId": "mch-snapshot-b",
"privateKey": "private-key-snapshot-b",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-snapshot-b",
"publicKeyId": "public-key-id-snapshot-b",
"certSerial": "cert-serial-snapshot-b",
})
instA, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-snapshot-a").
SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-snapshot-b").
SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(66).
SetPayAmount(66).
SetFeeRate(0).
SetRechargeCode("SNAPSHOT-WEBHOOK").
SetOutTradeNo("sub2_test_snapshot_webhook_order").
SetPaymentType(payment.TypeWxpay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderSnapshot(map[string]any{
"schema_version": 1,
"provider_instance_id": strconv.FormatInt(instA.ID, 10),
"provider_key": payment.TypeWxpay,
"payment_mode": "native",
}).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
registry: payment.NewRegistry(),
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
require.NoError(t, err)
require.Len(t, providers, 1)
require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
}