fix: harden payment resume and wxpay webhook routing

This commit is contained in:
IanShaw027
2026-04-21 01:40:56 +08:00
parent 0a461d8248
commit 1d8432b8a4
9 changed files with 526 additions and 31 deletions

View File

@@ -216,7 +216,11 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
if err != nil {
return nil, err
}
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
@@ -226,6 +230,44 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
return sel, nil
}
func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
if !requestNeedsWeChatJSAPICompatibility(req) {
return ctx, nil
}
if !s.usesOfficialWxpayVisibleMethod(ctx) {
return ctx, nil
}
expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
if err != nil {
return nil, err
}
return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
}
func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return false
}
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
}
func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
if s == nil || s.configService == nil || s.configService.settingRepo == nil {
return false
}
vals, err := s.configService.settingRepo.GetMultiple(ctx, []string{
SettingPaymentVisibleMethodWxpayEnabled,
SettingPaymentVisibleMethodWxpaySource,
})
if err != nil {
return false
}
if vals[SettingPaymentVisibleMethodWxpayEnabled] != "true" {
return false
}
return NormalizeVisibleMethodSource(payment.TypeWxpay, vals[SettingPaymentVisibleMethodWxpaySource]) == VisibleMethodSourceOfficialWechat
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
@@ -239,16 +281,18 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
ProviderInstanceID: sel.InstanceID,
ProviderKey: sel.ProviderKey,
PaymentType: req.PaymentType,
CanonicalReturnURL: canonicalReturnURL,
})
if err != nil {
return nil, fmt.Errorf("create payment resume token: %w", err)
if resume.isSigningConfigured() {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
ProviderInstanceID: sel.InstanceID,
ProviderKey: sel.ProviderKey,
PaymentType: req.PaymentType,
CanonicalReturnURL: canonicalReturnURL,
})
if err != nil {
return nil, fmt.Errorf("create payment resume token: %w", err)
}
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)

View File

@@ -0,0 +1,112 @@
package service
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
const jsapiTestEncryptionKey = "0123456789abcdef0123456789abcdef"
func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
compatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
"appId": "wx-merchant-app",
"mpAppId": "wx-mp-app",
"mchId": "mch-compatible",
"privateKey": "private-key",
"apiV3Key": jsapiTestEncryptionKey,
"publicKey": "public-key",
"publicKeyId": "key-compatible",
"certSerial": "serial-compatible",
})
incompatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
"appId": "wx-merchant-other",
"mpAppId": "wx-mp-other",
"mchId": "mch-incompatible",
"privateKey": "private-key",
"apiV3Key": jsapiTestEncryptionKey,
"publicKey": "public-key",
"publicKeyId": "key-incompatible",
"certSerial": "serial-incompatible",
})
compatible, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-compatible").
SetConfig(compatibleConfig).
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create compatible wxpay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-incompatible").
SetConfig(incompatibleConfig).
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create incompatible wxpay instance: %v", err)
}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}},
encryptionKey: []byte(jsapiTestEncryptionKey),
}
loadBalancer := newVisibleMethodLoadBalancer(
payment.NewDefaultLoadBalancer(client, []byte(jsapiTestEncryptionKey)),
configService,
)
svc := &PaymentService{
entClient: client,
loadBalancer: loadBalancer,
configService: configService,
}
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{
PaymentType: payment.TypeWxpay,
OpenID: "openid-123",
IsWeChatBrowser: true,
}, &PaymentConfig{LoadBalanceStrategy: string(payment.StrategyRoundRobin)}, 12.5)
if err != nil {
t.Fatalf("selectCreateOrderInstance returned error: %v", err)
}
if sel == nil {
t.Fatal("expected selected instance, got nil")
}
expectedInstanceID := fmt.Sprintf("%d", compatible.ID)
if sel.InstanceID != expectedInstanceID {
t.Fatalf("selected instance id = %q, want %q", sel.InstanceID, expectedInstanceID)
}
}
func mustEncryptJSAPITestConfig(t *testing.T, config map[string]string) string {
t.Helper()
data, err := json.Marshal(config)
if err != nil {
t.Fatalf("marshal config: %v", err)
}
encrypted, err := payment.Encrypt(string(data), []byte(jsapiTestEncryptionKey))
if err != nil {
t.Fatalf("encrypt config: %v", err)
}
return encrypted
}

View File

@@ -33,6 +33,9 @@ const (
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
wechatPaymentResumeTokenType = "wechat_payment_resume"
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
)
type ResumeTokenClaims struct {
@@ -70,6 +73,17 @@ func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func (s *PaymentResumeService) isSigningConfigured() bool {
return s != nil && len(s.signingKey) > 0
}
func (s *PaymentResumeService) ensureSigningKey() error {
if s.isSigningConfigured() {
return nil
}
return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
@@ -240,6 +254,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
@@ -250,6 +267,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims ResumeTokenClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
@@ -261,6 +281,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
}
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
}
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return "", fmt.Errorf("wechat payment resume token requires openid")
@@ -282,6 +305,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
}
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
if err := s.ensureSigningKey(); err != nil {
return nil, err
}
var claims WeChatPaymentResumeClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
@@ -330,11 +356,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
}
func (s *PaymentResumeService) sign(payload string) string {
key := s.signingKey
if len(key) == 0 {
key = []byte(paymentResumeFallbackSigningKey)
}
mac := hmac.New(sha256.New, key)
mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@@ -4,6 +4,10 @@ package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/url"
"strconv"
"testing"
@@ -150,6 +154,27 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
if err == nil {
t.Fatal("CreateToken should reject missing signing key")
}
}
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseToken(token)
if err == nil {
t.Fatal("ParseToken should reject tokens when signing key is missing")
}
}
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
@@ -183,6 +208,31 @@ func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService(nil)
_, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
if err == nil {
t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
}
}
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
TokenType: wechatPaymentResumeTokenType,
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
@@ -315,3 +365,17 @@ func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey stri
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
t.Helper()
payload, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey))
_, _ = mac.Write([]byte(encodedPayload))
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return encodedPayload + "." + signature
}

View File

@@ -16,33 +16,70 @@ import (
// It resolves the original provider instance from the order whenever possible and
// only falls back to a registry provider for legacy/single-instance scenarios.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
if err != nil {
return nil, err
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers[0], nil
}
// GetWebhookProviders returns provider candidates that can verify the webhook.
// Official WeChat Pay may require multiple candidates because the callback body
// cannot be bound to a merchant before decryption.
func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
if psHasPinnedProviderInstance(order) {
return s.getPinnedOrderProvider(ctx, order)
prov, err := s.getPinnedOrderProvider(ctx, order)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
inst, err := s.getOrderProviderInstance(ctx, order)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
prov, err := s.createProviderFromInstance(ctx, inst)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
}
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProviderByKey(providerKey)
prov, err := s.registry.GetProviderByKey(providerKey)
if err != nil {
return nil, err
}
return []payment.Provider{prov}, nil
}
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
@@ -78,3 +115,34 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
providerKey = strings.TrimSpace(providerKey)
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.ProviderKeyEQ(providerKey),
paymentproviderinstance.EnabledEQ(true),
).
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("query webhook provider instances: %w", err)
}
if len(instances) == 0 {
return nil, payment.ErrProviderNotFound
}
providers := make([]payment.Provider, 0, len(instances))
for _, inst := range instances {
prov, provErr := s.createProviderFromInstance(ctx, inst)
if provErr != nil {
slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
continue
}
providers = append(providers, prov)
}
if len(providers) == 0 {
return nil, payment.ErrProviderNotFound
}
return providers, nil
}

View File

@@ -208,10 +208,28 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-a",
"mchId": "mch-a",
"privateKey": "private-key-a",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-a",
"publicKeyId": "public-key-id-a",
"certSerial": "cert-serial-a",
})
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
"appId": "wx-app-b",
"mchId": "mch-b",
"privateKey": "private-key-b",
"apiV3Key": webhookProviderTestEncryptionKey,
"publicKey": "public-key-b",
"publicKeyId": "public-key-id-b",
"certSerial": "cert-serial-b",
})
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
SetConfig("{}").
SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
@@ -219,19 +237,51 @@ func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-b").
SetConfig("{}").
SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
loadBalancer: newWebhookProviderTestLoadBalancer(client),
registry: payment.NewRegistry(),
providersLoaded: true,
}
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
require.NoError(t, err)
require.Len(t, providers, 2)
}
func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-a").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("alipay-b").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentService{
entClient: client,
registry: payment.NewRegistry(),
providersLoaded: true,
}
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "")
_, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
require.Error(t, err)
require.Contains(t, err.Error(), "ambiguous")
}
@@ -260,8 +310,10 @@ func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
providersLoaded: true,
}
prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "")
providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
require.NoError(t, err)
require.Len(t, providers, 1)
prov := providers[0]
require.Equal(t, payment.TypeStripe, prov.ProviderKey())
}
@@ -308,7 +360,7 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
providersLoaded: true,
}
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
_, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}