From bdcd3d87e530fc1bbbec4888a51ae79c1cb0e298 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:09:38 +0800 Subject: [PATCH] fix: resolve unique legacy payment providers --- .../service/payment_order_lifecycle.go | 46 +++++--- backend/internal/service/payment_refund.go | 79 ++++++++++++- .../service/payment_webhook_provider.go | 22 ++-- .../service/payment_webhook_provider_test.go | 109 ++++++++++++++++++ 4 files changed, 220 insertions(+), 36 deletions(-) diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index f804eb8b..24d8b7a2 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" "strconv" - "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -237,29 +236,38 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) - if providerKey == "" { - providerKey = s.registry.GetProviderKey(o.PaymentType) - } - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } +func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { + if inst == nil { + return nil, fmt.Errorf("payment provider instance is missing") + } + + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + if inst.PaymentMode != "" { + cfg["paymentMode"] = inst.PaymentMode + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create provider from instance: %w", err) + } + return prov, nil +} + func psStringValue(value *string) string { if value == nil { return "" diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index c5bda763..01eecff8 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -19,18 +20,90 @@ import ( // --- Refund Flow --- // getOrderProviderInstance looks up the provider instance that processed this order. -// Returns nil, nil for legacy orders without provider_instance_id. +// For legacy orders without provider_instance_id, it resolves only when the +// enabled instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { - if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + if s == nil || s.entClient == nil || o == nil { return nil, nil } - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) if err != nil { return nil, nil } return s.entClient.PaymentProviderInstance.Get(ctx, instID) } +func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey != "" { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.EnabledEQ(true), + paymentproviderinstance.ProviderKeyEQ(providerKey), + ). + All(ctx) + if err != nil { + return nil, err + } + if len(instances) == 1 { + return instances[0], nil + } + return nil, nil + } + + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) + if paymentType == "" { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, err + } + + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(paymentType, inst) { + matched = append(matched, inst) + } + } + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { + if inst == nil { + return false + } + + baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType)) + instanceProviderKey := strings.TrimSpace(inst.ProviderKey) + if baseType == "" { + return false + } + + if baseType == payment.TypeStripe { + return instanceProviderKey == payment.TypeStripe + } + if instanceProviderKey == payment.TypeStripe { + return false + } + if instanceProviderKey == baseType { + return true + } + return payment.InstanceSupportsType(inst.SupportedTypes, baseType) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go index a877db2b..289d63ed 100644 --- a/backend/internal/service/payment_webhook_provider.go +++ b/backend/internal/service/payment_webhook_provider.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "log/slog" - "strconv" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" - "github.com/Wei-Shaw/sub2api/internal/payment/provider" ) // GetWebhookProvider returns the provider instance that should verify a webhook. @@ -24,6 +22,13 @@ func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, ou if psHasPinnedProviderInstance(order) { return s.getPinnedOrderProvider(ctx, order) } + inst, err := s.getOrderProviderInstance(ctx, order) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) + } if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) } @@ -48,18 +53,7 @@ func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.Pa if inst == nil { return nil, fmt.Errorf("order %d provider instance is missing", o.ID) } - - instID := strconv.FormatInt(int64(inst.ID), 10) - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) - if err != nil { - return nil, fmt.Errorf("load provider instance config: %w", err) - } - - prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) - if err != nil { - return nil, fmt.Errorf("create pinned provider: %w", err) - } - return prov, nil + return s.createProviderFromInstance(ctx, inst) } func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 85c296de..33e4186d 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -4,13 +4,17 @@ package service import ( "context" + "encoding/json" "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/require" ) +const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef" + type webhookProviderTestDouble struct { key string types []payment.PaymentType @@ -32,6 +36,111 @@ func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest panic("unexpected call") } +func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + require.NoError(t, err) + + encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey)) + require.NoError(t, err) + return encrypted +} + +func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer { + return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpayDirect, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("easypay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t)