fix: harden payment resume and wxpay webhook routing

This commit is contained in:
IanShaw027
2026-04-21 01:40:56 +08:00
parent 0a461d8248
commit 1d8432b8a4
9 changed files with 526 additions and 31 deletions

View File

@@ -1,6 +1,8 @@
package handler
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
@@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo := extractOutTradeNo(rawBody, providerKey)
provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
if providerKey == payment.TypeWxpay {
c.String(http.StatusBadRequest, "verify failed")
return
}
writeSuccessResponse(c, providerKey)
return
}
@@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
headers[strings.ToLower(k)] = c.GetHeader(k)
}
notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
if err != nil {
truncatedBody := rawBody
if len(truncatedBody) > webhookLogTruncateLen {
@@ -103,17 +109,17 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if notification == nil {
writeSuccessResponse(c, providerKey)
writeSuccessResponse(c, resolvedProviderKey)
return
}
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
}
writeSuccessResponse(c, providerKey)
writeSuccessResponse(c, resolvedProviderKey)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
@@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
return ""
}
func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
var lastErr error
for _, provider := range providers {
if provider == nil {
continue
}
notification, err := provider.VerifyNotification(ctx, rawBody, headers)
if err != nil {
lastErr = err
continue
}
return provider.ProviderKey(), notification, nil
}
if lastErr != nil {
return "", nil, lastErr
}
return "", nil, fmt.Errorf("no webhook provider could verify notification")
}
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`

View File

@@ -3,11 +3,14 @@
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -131,3 +134,70 @@ func TestExtractOutTradeNo(t *testing.T) {
})
}
}
func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
firstErr := errors.New("wrong provider")
providers := []payment.Provider{
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: firstErr,
},
webhookHandlerProviderStub{
key: payment.TypeWxpay,
notification: &payment.PaymentNotification{
OrderID: "sub2_42",
TradeNo: "trade-42",
Status: payment.NotificationStatusSuccess,
},
},
}
providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
require.NoError(t, err)
require.Equal(t, payment.TypeWxpay, providerKey)
require.NotNil(t, notification)
require.Equal(t, "sub2_42", notification.OrderID)
}
func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
providers := []payment.Provider{
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: errors.New("verify failed a"),
},
webhookHandlerProviderStub{
key: payment.TypeWxpay,
verifyErr: errors.New("verify failed b"),
},
}
_, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
require.Error(t, err)
}
type webhookHandlerProviderStub struct {
key string
notification *payment.PaymentNotification
verifyErr error
}
func (p webhookHandlerProviderStub) Name() string { return p.key }
func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.PaymentType(p.key)}
}
func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
panic("unexpected call")
}
func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
panic("unexpected call")
}
func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
if p.verifyErr != nil {
return nil, p.verifyErr
}
return p.notification, nil
}
func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
panic("unexpected call")
}

View File

@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
counter atomic.Uint64
}
type contextKey string
const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
appID = strings.TrimSpace(appID)
if appID == "" {
return ctx
}
return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
}
func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
return strings.TrimSpace(appID)
}
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
}
var matched []*dbent.PaymentProviderInstance
expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
for _, inst := range instances {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
config, cfgErr := lb.decryptConfig(inst.Config)
if cfgErr != nil {
slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
continue
}
if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
continue
}
}
matched = append(matched, inst)
}
}
@@ -358,6 +389,13 @@ func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
}
}
func resolveWxpayJSAPIAppID(config map[string]string) string {
if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
return appID
}
return strings.TrimSpace(config["appId"])
}
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)

View File

@@ -216,7 +216,11 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
if err != nil {
return nil, err
}
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
@@ -226,6 +230,44 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
return sel, nil
}
func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
if !requestNeedsWeChatJSAPICompatibility(req) {
return ctx, nil
}
if !s.usesOfficialWxpayVisibleMethod(ctx) {
return ctx, nil
}
expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
if err != nil {
return nil, err
}
return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
}
func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return false
}
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
}
func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
if s == nil || s.configService == nil || s.configService.settingRepo == nil {
return false
}
vals, err := s.configService.settingRepo.GetMultiple(ctx, []string{
SettingPaymentVisibleMethodWxpayEnabled,
SettingPaymentVisibleMethodWxpaySource,
})
if err != nil {
return false
}
if vals[SettingPaymentVisibleMethodWxpayEnabled] != "true" {
return false
}
return NormalizeVisibleMethodSource(payment.TypeWxpay, vals[SettingPaymentVisibleMethodWxpaySource]) == VisibleMethodSourceOfficialWechat
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
@@ -239,16 +281,18 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
ProviderInstanceID: sel.InstanceID,
ProviderKey: sel.ProviderKey,
PaymentType: req.PaymentType,
CanonicalReturnURL: canonicalReturnURL,
})
if err != nil {
return nil, fmt.Errorf("create payment resume token: %w", err)
if resume.isSigningConfigured() {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
ProviderInstanceID: sel.InstanceID,
ProviderKey: sel.ProviderKey,
PaymentType: req.PaymentType,
CanonicalReturnURL: canonicalReturnURL,
})
if err != nil {
return nil, fmt.Errorf("create payment resume token: %w", err)
}
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)

View File

@@ -0,0 +1,112 @@
package service
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
const jsapiTestEncryptionKey = "0123456789abcdef0123456789abcdef"
func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
compatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
"appId": "wx-merchant-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-compatible",
"privateKey": "private-key",
"apiV3Key": jsapiTestEncryptionKey,
"publicKey": "public-key",
"publicKeyId": "key-compatible",
"certSerial": "serial-compatible",
})
incompatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
"appId": "wx-merchant-other",
"mpAppId": "wx-mp-other",
"mchId": "mch-incompatible",
"privateKey": "private-key",
"apiV3Key": jsapiTestEncryptionKey,
"publicKey": "public-key",
"publicKeyId": "key-incompatible",
"certSerial": "serial-incompatible",
})
compatible, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-compatible").
SetConfig(compatibleConfig).
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create compatible wxpay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-incompatible").
SetConfig(incompatibleConfig).
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create incompatible wxpay instance: %v", err)
}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}},
encryptionKey: []byte(jsapiTestEncryptionKey),
}
loadBalancer := newVisibleMethodLoadBalancer(
payment.NewDefaultLoadBalancer(client, []byte(jsapiTestEncryptionKey)),
configService,
)
svc := &PaymentService{
entClient: client,
loadBalancer: loadBalancer,
configService: configService,
}
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{
PaymentType: payment.TypeWxpay,
OpenID: "openid-123",
IsWeChatBrowser: true,
}, &PaymentConfig{LoadBalanceStrategy: string(payment.StrategyRoundRobin)}, 12.5)
if err != nil {
t.Fatalf("selectCreateOrderInstance returned error: %v", err)
}
if sel == nil {
t.Fatal("expected selected instance, got nil")
}
expectedInstanceID := fmt.Sprintf("%d", compatible.ID)
if sel.InstanceID != expectedInstanceID {
t.Fatalf("selected instance id = %q, want %q", sel.InstanceID, expectedInstanceID)
}
}
func mustEncryptJSAPITestConfig(t *testing.T, config map[string]string) string {
t.Helper()
data, err := json.Marshal(config)
if err != nil {
t.Fatalf("marshal config: %v", err)
}
encrypted, err := payment.Encrypt(string(data), []byte(jsapiTestEncryptionKey))
if err != nil {
t.Fatalf("encrypt config: %v", err)
}
return encrypted
}

View File

@@ -33,6 +33,9 @@ const (
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
wechatPaymentResumeTokenType = "wechat_payment_resume"
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
)
type ResumeTokenClaims struct {
@@ -70,6 +73,17 @@ func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func (s *PaymentResumeService) isSigningConfigured() bool {
return s != nil && len(s.signingKey) > 0
}
func (s *PaymentResumeService) ensureSigningKey() error {
if s.isSigningConfigured() {
return nil
}
return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
@@ -240,6 +254,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
@@ -250,6 +267,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims ResumeTokenClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
@@ -261,6 +281,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
}
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return "", fmt.Errorf("wechat payment resume token requires openid")
@@ -282,6 +305,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
}
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims WeChatPaymentResumeClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
@@ -330,11 +356,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
}
func (s *PaymentResumeService) sign(payload string) string {
key := s.signingKey
if len(key) == 0 {
key = []byte(paymentResumeFallbackSigningKey)
}
mac := hmac.New(sha256.New, key)
mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@@ -4,6 +4,10 @@ package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/url"
"strconv"
"testing"
@@ -150,6 +154,27 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
if err == nil {
t.Fatal("CreateToken should reject missing signing key")
}
}
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject tokens when signing key is missing")
}
}
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
@@ -183,6 +208,31 @@ func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
if err == nil {
t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
}
}
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
TokenType: wechatPaymentResumeTokenType,
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
@@ -315,3 +365,17 @@ func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey stri
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
t.Helper()
payload, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey))
_, _ = mac.Write([]byte(encodedPayload))
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return encodedPayload + "." + signature
}

View File

@@ -16,33 +16,70 @@ import (
// It resolves the original provider instance from the order whenever possible and
// only falls back to a registry provider for legacy/single-instance scenarios.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
if err != nil {
return nil, err
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers[0], nil
}
// GetWebhookProviders returns provider candidates that can verify the webhook.
// Official WeChat Pay may require multiple candidates because the callback body
// cannot be bound to a merchant before decryption.
func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
if psHasPinnedProviderInstance(order) {
return s.getPinnedOrderProvider(ctx, order)
prov, err := s.getPinnedOrderProvider(ctx, order)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
inst, err := s.getOrderProviderInstance(ctx, order)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
prov, err := s.createProviderFromInstance(ctx, inst)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
@@ -78,3 +115,34 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
providerKey = strings.TrimSpace(providerKey)
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query webhook provider instances: %w", err)
}
if len(instances) == 0 {
return nil, payment.ErrProviderNotFound
}
providers := make([]payment.Provider, 0, len(instances))
for _, inst := range instances {
prov, provErr := s.createProviderFromInstance(ctx, inst)
if provErr != nil {
slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
continue
}
providers = append(providers, prov)
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers, nil
}

View File

@@ -208,10 +208,28 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-a",
"mchId": "mch-a",
"privateKey": "private-key-a",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-a",
"publicKeyId": "public-key-id-a",
"certSerial": "cert-serial-a",
})
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-b",
"mchId": "mch-b",
"privateKey": "private-key-b",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-b",
"publicKeyId": "public-key-id-b",
"certSerial": "cert-serial-b",
})
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
SetConfig("{}").
SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
@@ -219,19 +237,51 @@ func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-b").
SetConfig("{}").
SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
registry: payment.NewRegistry(),
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
require.NoError(t, err)
require.Len(t, providers, 2)
}
func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-a").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-b").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
registry: payment.NewRegistry(),
providersLoaded: true,
}
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "")
_, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
require.Error(t, err)
require.Contains(t, err.Error(), "ambiguous")
}
@@ -260,8 +310,10 @@ func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
providersLoaded: true,
}
prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "")
providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
require.NoError(t, err)
require.Len(t, providers, 1)
prov := providers[0]
require.Equal(t, payment.TypeStripe, prov.ProviderKey())
}
@@ -308,7 +360,7 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
providersLoaded: true,
}
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
_, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}