fix: pin payment read paths to provider snapshots
This commit is contained in:
@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
|
|||||||
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
|
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
|
||||||
instanceProviderKey = inst.ProviderKey
|
instanceProviderKey = inst.ProviderKey
|
||||||
}
|
}
|
||||||
expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
|
expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
|
||||||
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
|
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
|
||||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
|
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
|
||||||
"expectedProvider": expectedProviderKey,
|
"expectedProvider": expectedProviderKey,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi
|
|||||||
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
|
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
registry := payment.NewRegistry()
|
||||||
|
registry.Register(paymentFulfillmentTestProvider{
|
||||||
|
key: payment.TypeAlipay,
|
||||||
|
supportedTypes: []payment.PaymentType{payment.TypeAlipay},
|
||||||
|
})
|
||||||
|
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
PaymentType: payment.TypeAlipay,
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"schema_version": 1,
|
||||||
|
"provider_key": payment.TypeEasyPay,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
payment.TypeEasyPay,
|
||||||
|
expectedNotificationProviderKeyForOrder(registry, order, ""),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
115
backend/internal/service/payment_order_provider_snapshot.go
Normal file
115
backend/internal/service/payment_order_provider_snapshot.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
)
|
||||||
|
|
||||||
|
type paymentOrderProviderSnapshot struct {
|
||||||
|
SchemaVersion int
|
||||||
|
ProviderInstanceID string
|
||||||
|
ProviderKey string
|
||||||
|
PaymentMode string
|
||||||
|
}
|
||||||
|
|
||||||
|
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
|
||||||
|
if order == nil || len(order.ProviderSnapshot) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot := &paymentOrderProviderSnapshot{
|
||||||
|
SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
|
||||||
|
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
|
||||||
|
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
|
||||||
|
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
|
||||||
|
}
|
||||||
|
if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func psSnapshotStringValue(value any) string {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func psSnapshotIntValue(value any) int {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case int:
|
||||||
|
return typed
|
||||||
|
case int32:
|
||||||
|
return int(typed)
|
||||||
|
case int64:
|
||||||
|
return int(typed)
|
||||||
|
case float32:
|
||||||
|
return int(typed)
|
||||||
|
case float64:
|
||||||
|
return int(typed)
|
||||||
|
case string:
|
||||||
|
n, err := strconv.Atoi(strings.TrimSpace(typed))
|
||||||
|
if err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
|
||||||
|
if s == nil || s.entClient == nil || order == nil || snapshot == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
|
||||||
|
columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
|
||||||
|
if snapshotInstanceID == "" {
|
||||||
|
snapshotInstanceID = columnInstanceID
|
||||||
|
}
|
||||||
|
if snapshotInstanceID == "" {
|
||||||
|
return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
|
||||||
|
}
|
||||||
|
if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
|
||||||
|
return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
|
||||||
|
return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return inst, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
|
||||||
|
if order == nil {
|
||||||
|
return strings.TrimSpace(instanceProviderKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderProviderKey := psStringValue(order.ProviderKey)
|
||||||
|
if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
|
||||||
|
orderProviderKey = snapshot.ProviderKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
|
||||||
|
}
|
||||||
@@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
|
||||||
|
return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
|
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
|
||||||
if instIDStr == "" {
|
if instIDStr == "" {
|
||||||
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
|
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
|
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
|
||||||
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
|
return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
|
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
|
|||||||
require.Nil(t, got)
|
require.Nil(t, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
inst, err := client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeStripe).
|
||||||
|
SetName("stripe-snapshot").
|
||||||
|
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
|
||||||
|
SetSupportedTypes("stripe").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
ID: 42,
|
||||||
|
PaymentType: payment.TypeStripe,
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"schema_version": 1,
|
||||||
|
"provider_instance_id": strconv.FormatInt(inst.ID, 10),
|
||||||
|
"provider_key": payment.TypeStripe,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &PaymentService{
|
||||||
|
entClient: client,
|
||||||
|
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := svc.getOrderProviderInstance(ctx, order)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
require.Equal(t, inst.ID, got.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
_, err := client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeStripe).
|
||||||
|
SetName("stripe-legacy-fallback").
|
||||||
|
SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
|
||||||
|
SetSupportedTypes("stripe").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
order := &dbent.PaymentOrder{
|
||||||
|
ID: 43,
|
||||||
|
PaymentType: payment.TypeStripe,
|
||||||
|
ProviderSnapshot: map[string]any{
|
||||||
|
"schema_version": 1,
|
||||||
|
"provider_instance_id": "999999",
|
||||||
|
"provider_key": payment.TypeStripe,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &PaymentService{
|
||||||
|
entClient: client,
|
||||||
|
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := svc.getOrderProviderInstance(ctx, order)
|
||||||
|
require.Nil(t, got)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
|
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := newPaymentConfigServiceTestClient(t)
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
@@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "provider instance")
|
require.Contains(t, err.Error(), "provider instance")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("snapshot-webhook@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetUsername("snapshot-webhook").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
|
||||||
|
"appId": "wx-app-snapshot-a",
|
||||||
|
"mchId": "mch-snapshot-a",
|
||||||
|
"privateKey": "private-key-snapshot-a",
|
||||||
|
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||||
|
"publicKey": "public-key-snapshot-a",
|
||||||
|
"publicKeyId": "public-key-id-snapshot-a",
|
||||||
|
"certSerial": "cert-serial-snapshot-a",
|
||||||
|
})
|
||||||
|
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
|
||||||
|
"appId": "wx-app-snapshot-b",
|
||||||
|
"mchId": "mch-snapshot-b",
|
||||||
|
"privateKey": "private-key-snapshot-b",
|
||||||
|
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||||
|
"publicKey": "public-key-snapshot-b",
|
||||||
|
"publicKeyId": "public-key-id-snapshot-b",
|
||||||
|
"certSerial": "cert-serial-snapshot-b",
|
||||||
|
})
|
||||||
|
instA, err := client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeWxpay).
|
||||||
|
SetName("wxpay-snapshot-a").
|
||||||
|
SetConfig(wxpayConfigA).
|
||||||
|
SetSupportedTypes("wxpay").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeWxpay).
|
||||||
|
SetName("wxpay-snapshot-b").
|
||||||
|
SetConfig(wxpayConfigB).
|
||||||
|
SetSupportedTypes("wxpay").
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.PaymentOrder.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetUserEmail(user.Email).
|
||||||
|
SetUserName(user.Username).
|
||||||
|
SetAmount(66).
|
||||||
|
SetPayAmount(66).
|
||||||
|
SetFeeRate(0).
|
||||||
|
SetRechargeCode("SNAPSHOT-WEBHOOK").
|
||||||
|
SetOutTradeNo("sub2_test_snapshot_webhook_order").
|
||||||
|
SetPaymentType(payment.TypeWxpay).
|
||||||
|
SetPaymentTradeNo("").
|
||||||
|
SetOrderType(payment.OrderTypeBalance).
|
||||||
|
SetStatus(OrderStatusPending).
|
||||||
|
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||||
|
SetClientIP("127.0.0.1").
|
||||||
|
SetSrcHost("api.example.com").
|
||||||
|
SetProviderSnapshot(map[string]any{
|
||||||
|
"schema_version": 1,
|
||||||
|
"provider_instance_id": strconv.FormatInt(instA.ID, 10),
|
||||||
|
"provider_key": payment.TypeWxpay,
|
||||||
|
"payment_mode": "native",
|
||||||
|
}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &PaymentService{
|
||||||
|
entClient: client,
|
||||||
|
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||||
|
registry: payment.NewRegistry(),
|
||||||
|
providersLoaded: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, providers, 1)
|
||||||
|
require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user