From 1d8432b8a4a3f2f6fde27ee0d5c01f00783a1c04 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 01:40:56 +0800 Subject: [PATCH] fix: harden payment resume and wxpay webhook routing --- .../handler/payment_webhook_handler.go | 37 +++++- .../handler/payment_webhook_handler_test.go | 70 +++++++++++ backend/internal/payment/load_balancer.go | 38 ++++++ backend/internal/service/payment_order.go | 66 +++++++++-- .../service/payment_order_jsapi_test.go | 112 ++++++++++++++++++ .../service/payment_resume_service.go | 32 ++++- .../service/payment_resume_service_test.go | 64 ++++++++++ .../service/payment_webhook_provider.go | 76 +++++++++++- .../service/payment_webhook_provider_test.go | 62 +++++++++- 9 files changed, 526 insertions(+), 31 deletions(-) create mode 100644 backend/internal/service/payment_order_jsapi_test.go diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 9fdefa93..c06a5b7e 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -1,6 +1,8 @@ package handler import ( + "context" + "fmt" "io" "log/slog" "net/http" @@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts). outTradeNo := extractOutTradeNo(rawBody, providerKey) - provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo) + providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo) if err != nil { slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err) + if providerKey == payment.TypeWxpay { + c.String(http.StatusBadRequest, "verify failed") + return + } writeSuccessResponse(c, providerKey) return } @@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) headers[strings.ToLower(k)] = c.GetHeader(k) } - notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers) + resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers) if err != nil { truncatedBody := rawBody if len(truncatedBody) > webhookLogTruncateLen { @@ -103,17 +109,17 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // nil notification means irrelevant event (e.g. Stripe non-payment event); return success. if notification == nil { - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) return } - if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil { - slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err) + if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil { + slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err) c.String(http.StatusInternalServerError, "handle failed") return } - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) } // extractOutTradeNo parses the webhook body to find the out_trade_no. @@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string { return "" } +func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) { + var lastErr error + for _, provider := range providers { + if provider == nil { + continue + } + notification, err := provider.VerifyNotification(ctx, rawBody, headers) + if err != nil { + lastErr = err + continue + } + return provider.ProviderKey(), notification, nil + } + if lastErr != nil { + return "", nil, lastErr + } + return "", nil, fmt.Errorf("no webhook provider could verify notification") +} + // wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook. type wxpaySuccessResponse struct { Code string `json:"code"` diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index 6f448131..88221b5c 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -3,11 +3,14 @@ package handler import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -131,3 +134,70 @@ func TestExtractOutTradeNo(t *testing.T) { }) } } + +func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) { + firstErr := errors.New("wrong provider") + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: firstErr, + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + notification: &payment.PaymentNotification{ + OrderID: "sub2_42", + TradeNo: "trade-42", + Status: payment.NotificationStatusSuccess, + }, + }, + } + + providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"}) + require.NoError(t, err) + require.Equal(t, payment.TypeWxpay, providerKey) + require.NotNil(t, notification) + require.Equal(t, "sub2_42", notification.OrderID) +} + +func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) { + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed a"), + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed b"), + }, + } + + _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil) + require.Error(t, err) +} + +type webhookHandlerProviderStub struct { + key string + notification *payment.PaymentNotification + verifyErr error +} + +func (p webhookHandlerProviderStub) Name() string { return p.key } +func (p webhookHandlerProviderStub) ProviderKey() string { return p.key } +func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.PaymentType(p.key)} +} +func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + if p.verifyErr != nil { + return nil, p.verifyErr + } + return p.notification, nil +} +func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index ec4ed1d3..ddf792bb 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -45,11 +45,31 @@ type DefaultLoadBalancer struct { counter atomic.Uint64 } +type contextKey string + +const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id" + // NewDefaultLoadBalancer creates a new load balancer. func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer { return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey} } +func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context { + appID = strings.TrimSpace(appID) + if appID == "" { + return ctx + } + return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID) +} + +func wxpayJSAPIAppIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string) + return strings.TrimSpace(appID) +} + // instanceCandidate pairs an instance with its pre-fetched daily usage. type instanceCandidate struct { inst *dbent.PaymentProviderInstance @@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( } var matched []*dbent.PaymentProviderInstance + expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx) for _, inst := range instances { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". @@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( matched = append(matched, inst) } } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { + if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay { + config, cfgErr := lb.decryptConfig(inst.Config) + if cfgErr != nil { + slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr) + continue + } + if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID { + continue + } + } matched = append(matched, inst) } } @@ -358,6 +389,13 @@ func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType { } } +func resolveWxpayJSAPIAppID(config map[string]string) string { + if appID := strings.TrimSpace(config["mpAppId"]); appID != "" { + return appID + } + return strings.TrimSpace(config["appId"]) +} + // GetInstanceConfig decrypts and returns the configuration for a provider instance by ID. func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID) diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 221d6b94..7d973b92 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -216,7 +216,11 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user } func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) { - sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) + selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req) + if err != nil { + return nil, err + } + sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) if err != nil { return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) } @@ -226,6 +230,44 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea return sel, nil } +func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) { + if !requestNeedsWeChatJSAPICompatibility(req) { + return ctx, nil + } + if !s.usesOfficialWxpayVisibleMethod(ctx) { + return ctx, nil + } + expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return nil, err + } + return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil +} + +func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool { + if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return false + } + return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" +} + +func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool { + if s == nil || s.configService == nil || s.configService.settingRepo == nil { + return false + } + vals, err := s.configService.settingRepo.GetMultiple(ctx, []string{ + SettingPaymentVisibleMethodWxpayEnabled, + SettingPaymentVisibleMethodWxpaySource, + }) + if err != nil { + return false + } + if vals[SettingPaymentVisibleMethodWxpayEnabled] != "true" { + return false + } + return NormalizeVisibleMethodSource(payment.TypeWxpay, vals[SettingPaymentVisibleMethodWxpaySource]) == VisibleMethodSourceOfficialWechat +} + func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { @@ -239,16 +281,18 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } resumeToken := "" if resume := s.paymentResume(); resume != nil { - resumeToken, err = resume.CreateToken(ResumeTokenClaims{ - OrderID: order.ID, - UserID: order.UserID, - ProviderInstanceID: sel.InstanceID, - ProviderKey: sel.ProviderKey, - PaymentType: req.PaymentType, - CanonicalReturnURL: canonicalReturnURL, - }) - if err != nil { - return nil, fmt.Errorf("create payment resume token: %w", err) + if resume.isSigningConfigured() { + resumeToken, err = resume.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: order.UserID, + ProviderInstanceID: sel.InstanceID, + ProviderKey: sel.ProviderKey, + PaymentType: req.PaymentType, + CanonicalReturnURL: canonicalReturnURL, + }) + if err != nil { + return nil, fmt.Errorf("create payment resume token: %w", err) + } } } providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken) diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go new file mode 100644 index 00000000..08492432 --- /dev/null +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +const jsapiTestEncryptionKey = "0123456789abcdef0123456789abcdef" + +func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + compatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{ + "appId": "wx-merchant-app", + "mpAppId": "wx-mp-app", + "mchId": "mch-compatible", + "privateKey": "private-key", + "apiV3Key": jsapiTestEncryptionKey, + "publicKey": "public-key", + "publicKeyId": "key-compatible", + "certSerial": "serial-compatible", + }) + incompatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{ + "appId": "wx-merchant-other", + "mpAppId": "wx-mp-other", + "mchId": "mch-incompatible", + "privateKey": "private-key", + "apiV3Key": jsapiTestEncryptionKey, + "publicKey": "public-key", + "publicKeyId": "key-incompatible", + "certSerial": "serial-incompatible", + }) + + compatible, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-compatible"). + SetConfig(compatibleConfig). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create compatible wxpay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-incompatible"). + SetConfig(incompatibleConfig). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create incompatible wxpay instance: %v", err) + } + + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + }}, + encryptionKey: []byte(jsapiTestEncryptionKey), + } + loadBalancer := newVisibleMethodLoadBalancer( + payment.NewDefaultLoadBalancer(client, []byte(jsapiTestEncryptionKey)), + configService, + ) + svc := &PaymentService{ + entClient: client, + loadBalancer: loadBalancer, + configService: configService, + } + + t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") + t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") + + sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{ + PaymentType: payment.TypeWxpay, + OpenID: "openid-123", + IsWeChatBrowser: true, + }, &PaymentConfig{LoadBalanceStrategy: string(payment.StrategyRoundRobin)}, 12.5) + if err != nil { + t.Fatalf("selectCreateOrderInstance returned error: %v", err) + } + if sel == nil { + t.Fatal("expected selected instance, got nil") + } + expectedInstanceID := fmt.Sprintf("%d", compatible.ID) + if sel.InstanceID != expectedInstanceID { + t.Fatalf("selected instance id = %q, want %q", sel.InstanceID, expectedInstanceID) + } +} + +func mustEncryptJSAPITestConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + encrypted, err := payment.Encrypt(string(data), []byte(jsapiTestEncryptionKey)) + if err != nil { + t.Fatalf("encrypt config: %v", err) + } + return encrypted +} diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 64d1d125..05997b38 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -33,6 +33,9 @@ const ( VisibleMethodSourceEasyPayWechat = "easypay_wxpay" wechatPaymentResumeTokenType = "wechat_payment_resume" + + paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED" + paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key" ) type ResumeTokenClaims struct { @@ -70,6 +73,17 @@ func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { return &PaymentResumeService{signingKey: signingKey} } +func (s *PaymentResumeService) isSigningConfigured() bool { + return s != nil && len(s.signingKey) > 0 +} + +func (s *PaymentResumeService) ensureSigningKey() error { + if s.isSigningConfigured() { + return nil + } + return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage) +} + func NormalizeVisibleMethod(method string) string { return payment.GetBasePaymentType(strings.TrimSpace(method)) } @@ -240,6 +254,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri } func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } if claims.OrderID <= 0 { return "", fmt.Errorf("resume token requires order id") } @@ -250,6 +267,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er } func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } var claims ResumeTokenClaims if err := s.parseSignedToken(token, &claims); err != nil { return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") @@ -261,6 +281,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err } func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } claims.OpenID = strings.TrimSpace(claims.OpenID) if claims.OpenID == "" { return "", fmt.Errorf("wechat payment resume token requires openid") @@ -282,6 +305,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme } func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } var claims WeChatPaymentResumeClaims if err := s.parseSignedToken(token, &claims); err != nil { return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid") @@ -330,11 +356,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { } func (s *PaymentResumeService) sign(payload string) string { - key := s.signingKey - if len(key) == 0 { - key = []byte(paymentResumeFallbackSigningKey) - } - mac := hmac.New(sha256.New, key) + mac := hmac.New(sha256.New, s.signingKey) _, _ = mac.Write([]byte(payload)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) } diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 24d50494..4ac89f0f 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -4,6 +4,10 @@ package service import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" "net/url" "strconv" "testing" @@ -150,6 +154,27 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) { } } +func TestCreateTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42}) + if err == nil { + t.Fatal("CreateToken should reject missing signing key") + } +} + +func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7}) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseToken(token) + if err == nil { + t.Fatal("ParseToken should reject tokens when signing key is missing") + } +} + func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) { t.Parallel() @@ -183,6 +208,31 @@ func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) { } } +func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"}) + if err == nil { + t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key") + } +} + +func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{ + TokenType: wechatPaymentResumeTokenType, + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + }) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseWeChatPaymentResumeToken(token) + if err == nil { + t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing") + } +} + func TestNormalizeVisibleMethodSource(t *testing.T) { t.Parallel() @@ -315,3 +365,17 @@ func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey stri c.lastPaymentType = paymentType return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil } + +func mustCreateFallbackSignedToken(t *testing.T, claims any) string { + t.Helper() + + payload, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey)) + _, _ = mac.Write([]byte(encodedPayload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return encodedPayload + "." + signature +} diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go index 289d63ed..82dc9ea3 100644 --- a/backend/internal/service/payment_webhook_provider.go +++ b/backend/internal/service/payment_webhook_provider.go @@ -16,33 +16,70 @@ import ( // It resolves the original provider instance from the order whenever possible and // only falls back to a registry provider for legacy/single-instance scenarios. func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo) + if err != nil { + return nil, err + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers[0], nil +} + +// GetWebhookProviders returns provider candidates that can verify the webhook. +// Official WeChat Pay may require multiple candidates because the callback body +// cannot be bound to a merchant before decryption. +func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) { if outTradeNo != "" { order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) if err == nil { if psHasPinnedProviderInstance(order) { - return s.getPinnedOrderProvider(ctx, order) + prov, err := s.getPinnedOrderProvider(ctx, order) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil } inst, err := s.getOrderProviderInstance(ctx, order) if err != nil { return nil, fmt.Errorf("load order provider instance: %w", err) } if inst != nil { - return s.createProviderFromInstance(ctx, inst) + prov, err := s.createProviderFromInstance(ctx, inst) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) } if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) } s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil } } + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) } s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil } func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { @@ -78,3 +115,34 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "" } + +func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) { + providerKey = strings.TrimSpace(providerKey) + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("query webhook provider instances: %w", err) + } + if len(instances) == 0 { + return nil, payment.ErrProviderNotFound + } + + providers := make([]payment.Provider, 0, len(instances)) + for _, inst := range instances { + prov, provErr := s.createProviderFromInstance(ctx, inst) + if provErr != nil { + slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr) + continue + } + providers = append(providers, prov) + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers, nil +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 4f0b6848..15b447c2 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -208,10 +208,28 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) + wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-a", + "mchId": "mch-a", + "privateKey": "private-key-a", + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": "public-key-a", + "publicKeyId": "public-key-id-a", + "certSerial": "cert-serial-a", + }) + wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-b", + "mchId": "mch-b", + "privateKey": "private-key-b", + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": "public-key-b", + "publicKeyId": "public-key-id-b", + "certSerial": "cert-serial-b", + }) _, err := client.PaymentProviderInstance.Create(). SetProviderKey(payment.TypeWxpay). SetName("wxpay-a"). - SetConfig("{}"). + SetConfig(wxpayConfigA). SetSupportedTypes("wxpay"). SetEnabled(true). Save(ctx) @@ -219,19 +237,51 @@ func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { _, err = client.PaymentProviderInstance.Create(). SetProviderKey(payment.TypeWxpay). SetName("wxpay-b"). - SetConfig("{}"). + SetConfig(wxpayConfigB). SetSupportedTypes("wxpay"). SetEnabled(true). Save(ctx) require.NoError(t, err) + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "") + require.NoError(t, err) + require.Len(t, providers, 2) +} + +func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-a"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-b"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + svc := &PaymentService{ entClient: client, registry: payment.NewRegistry(), providersLoaded: true, } - _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "") + _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "") require.Error(t, err) require.Contains(t, err.Error(), "ambiguous") } @@ -260,8 +310,10 @@ func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { providersLoaded: true, } - prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "") + providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "") require.NoError(t, err) + require.Len(t, providers, 1) + prov := providers[0] require.Equal(t, payment.TypeStripe, prov.ProviderKey()) } @@ -308,7 +360,7 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { providersLoaded: true, } - _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order") require.Error(t, err) require.Contains(t, err.Error(), "provider instance") }