diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 44818b37..519455f0 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -41,6 +41,19 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo slog.Error("order not found", "orderID", oid) return nil } + instanceProviderKey := "" + if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { + instanceProviderKey = inst.ProviderKey + } + expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey) + if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ + "expectedProvider": expectedProviderKey, + "actualProvider": pk, + "tradeNo": tradeNo, + }) + return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) + } // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Also skip if paid is NaN/Inf (malformed provider data). if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { @@ -56,6 +69,18 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string { + if key := strings.TrimSpace(instanceProviderKey); key != "" { + return key + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(orderPaymentType) +} + func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error { previousStatus := o.Status now := time.Now() diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 625b0d9f..4cc00301 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -3,12 +3,37 @@ package service import ( + "context" "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) +type paymentFulfillmentTestProvider struct { + key string + supportedTypes []payment.PaymentType +} + +func (p paymentFulfillmentTestProvider) Name() string { return p.key } +func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key } +func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType { + return p.supportedTypes +} +func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -161,3 +186,42 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) { assert.True(t, unusedCode.CanUse()) assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil)) } + +func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay), + ) +} + +func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeEasyPay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, ""), + ) +} + +func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { + t.Parallel() + + assert.Equal(t, + payment.TypeWxpay, + expectedNotificationProviderKey(nil, payment.TypeWxpay, ""), + ) +}