diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index e897741a..d3175ba6 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -9,7 +9,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -225,25 +224,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) { } } -// GetWebhookProvider returns the provider instance that should verify a webhook. -// It extracts out_trade_no from the raw body, looks up the order to find the -// original provider instance, and creates a provider with that instance's credentials. -// Falls back to the registry provider when the order cannot be found. -func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { - if outTradeNo != "" { - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) - if err == nil { - p, pErr := s.getOrderProvider(ctx, order) - if pErr == nil { - return p, nil - } - slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr) - } - } - s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) -} - // --- Helpers --- func psIsRefundStatus(s string) bool { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go new file mode 100644 index 00000000..a877db2b --- /dev/null +++ b/backend/internal/service/payment_webhook_provider.go @@ -0,0 +1,86 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/payment/provider" +) + +// GetWebhookProvider returns the provider instance that should verify a webhook. +// It resolves the original provider instance from the order whenever possible and +// only falls back to a registry provider for legacy/single-instance scenarios. +func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + if outTradeNo != "" { + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) + if err == nil { + if psHasPinnedProviderInstance(order) { + return s.getPinnedOrderProvider(ctx, order) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) + } + } + + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) +} + +func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst == nil { + return nil, fmt.Errorf("order %d provider instance is missing", o.ID) + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create pinned provider: %w", err) + } + return prov, nil +} + +func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" || s == nil || s.entClient == nil { + return false + } + + count, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Count(ctx) + if err != nil { + slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err) + return false + } + return count <= 1 +} + +func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { + return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "" +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go new file mode 100644 index 00000000..85c296de --- /dev/null +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -0,0 +1,141 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +type webhookProviderTestDouble struct { + key string + types []payment.PaymentType +} + +func (p webhookProviderTestDouble) Name() string { return p.key } +func (p webhookProviderTestDouble) ProviderKey() string { return p.key } +func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types } +func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-b"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + registry: payment.NewRegistry(), + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "") + require.Error(t, err) + require.Contains(t, err.Error(), "ambiguous") +} + +func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeStripe, + types: []payment.PaymentType{payment.TypeStripe}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "") + require.NoError(t, err) + require.Equal(t, payment.TypeStripe, prov.ProviderKey()) +} + +func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("webhook"). + Save(ctx) + require.NoError(t, err) + + pinnedInstanceID := "999" + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("TEST-RECHARGE"). + SetOutTradeNo("sub2_test_pinned_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(pinnedInstanceID). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeWxpay, + types: []payment.PaymentType{payment.TypeWxpay}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + require.Error(t, err) + require.Contains(t, err.Error(), "provider instance") +}