fix: harden payment resume and wxpay webhook routing
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
|
||||
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
|
||||
outTradeNo := extractOutTradeNo(rawBody, providerKey)
|
||||
|
||||
provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
|
||||
providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
|
||||
if err != nil {
|
||||
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
|
||||
if providerKey == payment.TypeWxpay {
|
||||
c.String(http.StatusBadRequest, "verify failed")
|
||||
return
|
||||
}
|
||||
writeSuccessResponse(c, providerKey)
|
||||
return
|
||||
}
|
||||
@@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
|
||||
headers[strings.ToLower(k)] = c.GetHeader(k)
|
||||
}
|
||||
|
||||
notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
|
||||
resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
|
||||
if err != nil {
|
||||
truncatedBody := rawBody
|
||||
if len(truncatedBody) > webhookLogTruncateLen {
|
||||
@@ -103,17 +109,17 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
|
||||
|
||||
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
|
||||
if notification == nil {
|
||||
writeSuccessResponse(c, providerKey)
|
||||
writeSuccessResponse(c, resolvedProviderKey)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
|
||||
slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
|
||||
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
|
||||
slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
|
||||
c.String(http.StatusInternalServerError, "handle failed")
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccessResponse(c, providerKey)
|
||||
writeSuccessResponse(c, resolvedProviderKey)
|
||||
}
|
||||
|
||||
// extractOutTradeNo parses the webhook body to find the out_trade_no.
|
||||
@@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
if provider == nil {
|
||||
continue
|
||||
}
|
||||
notification, err := provider.VerifyNotification(ctx, rawBody, headers)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return provider.ProviderKey(), notification, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return "", nil, lastErr
|
||||
}
|
||||
return "", nil, fmt.Errorf("no webhook provider could verify notification")
|
||||
}
|
||||
|
||||
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
|
||||
type wxpaySuccessResponse struct {
|
||||
Code string `json:"code"`
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -131,3 +134,70 @@ func TestExtractOutTradeNo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
|
||||
firstErr := errors.New("wrong provider")
|
||||
providers := []payment.Provider{
|
||||
webhookHandlerProviderStub{
|
||||
key: payment.TypeWxpay,
|
||||
verifyErr: firstErr,
|
||||
},
|
||||
webhookHandlerProviderStub{
|
||||
key: payment.TypeWxpay,
|
||||
notification: &payment.PaymentNotification{
|
||||
OrderID: "sub2_42",
|
||||
TradeNo: "trade-42",
|
||||
Status: payment.NotificationStatusSuccess,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, payment.TypeWxpay, providerKey)
|
||||
require.NotNil(t, notification)
|
||||
require.Equal(t, "sub2_42", notification.OrderID)
|
||||
}
|
||||
|
||||
func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
|
||||
providers := []payment.Provider{
|
||||
webhookHandlerProviderStub{
|
||||
key: payment.TypeWxpay,
|
||||
verifyErr: errors.New("verify failed a"),
|
||||
},
|
||||
webhookHandlerProviderStub{
|
||||
key: payment.TypeWxpay,
|
||||
verifyErr: errors.New("verify failed b"),
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
type webhookHandlerProviderStub struct {
|
||||
key string
|
||||
notification *payment.PaymentNotification
|
||||
verifyErr error
|
||||
}
|
||||
|
||||
func (p webhookHandlerProviderStub) Name() string { return p.key }
|
||||
func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
|
||||
func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
|
||||
return []payment.PaymentType{payment.PaymentType(p.key)}
|
||||
}
|
||||
func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
panic("unexpected call")
|
||||
}
|
||||
func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
|
||||
panic("unexpected call")
|
||||
}
|
||||
func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
|
||||
if p.verifyErr != nil {
|
||||
return nil, p.verifyErr
|
||||
}
|
||||
return p.notification, nil
|
||||
}
|
||||
func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
|
||||
|
||||
// NewDefaultLoadBalancer creates a new load balancer.
|
||||
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
|
||||
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
|
||||
}
|
||||
|
||||
func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
|
||||
appID = strings.TrimSpace(appID)
|
||||
if appID == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
|
||||
}
|
||||
|
||||
func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
|
||||
return strings.TrimSpace(appID)
|
||||
}
|
||||
|
||||
// instanceCandidate pairs an instance with its pre-fetched daily usage.
|
||||
type instanceCandidate struct {
|
||||
inst *dbent.PaymentProviderInstance
|
||||
@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
||||
}
|
||||
|
||||
var matched []*dbent.PaymentProviderInstance
|
||||
expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
|
||||
for _, inst := range instances {
|
||||
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
|
||||
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
|
||||
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
||||
matched = append(matched, inst)
|
||||
}
|
||||
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
|
||||
if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
|
||||
config, cfgErr := lb.decryptConfig(inst.Config)
|
||||
if cfgErr != nil {
|
||||
slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
|
||||
continue
|
||||
}
|
||||
if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
|
||||
continue
|
||||
}
|
||||
}
|
||||
matched = append(matched, inst)
|
||||
}
|
||||
}
|
||||
@@ -358,6 +389,13 @@ func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveWxpayJSAPIAppID(config map[string]string) string {
|
||||
if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
|
||||
return appID
|
||||
}
|
||||
return strings.TrimSpace(config["appId"])
|
||||
}
|
||||
|
||||
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
|
||||
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
|
||||
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
|
||||
|
||||
@@ -216,7 +216,11 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
|
||||
}
|
||||
|
||||
func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
|
||||
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
|
||||
selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
|
||||
if err != nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
|
||||
}
|
||||
@@ -226,6 +230,44 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
|
||||
return sel, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
|
||||
if !requestNeedsWeChatJSAPICompatibility(req) {
|
||||
return ctx, nil
|
||||
}
|
||||
if !s.usesOfficialWxpayVisibleMethod(ctx) {
|
||||
return ctx, nil
|
||||
}
|
||||
expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
|
||||
}
|
||||
|
||||
func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
|
||||
if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
|
||||
return false
|
||||
}
|
||||
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
|
||||
}
|
||||
|
||||
func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
|
||||
if s == nil || s.configService == nil || s.configService.settingRepo == nil {
|
||||
return false
|
||||
}
|
||||
vals, err := s.configService.settingRepo.GetMultiple(ctx, []string{
|
||||
SettingPaymentVisibleMethodWxpayEnabled,
|
||||
SettingPaymentVisibleMethodWxpaySource,
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if vals[SettingPaymentVisibleMethodWxpayEnabled] != "true" {
|
||||
return false
|
||||
}
|
||||
return NormalizeVisibleMethodSource(payment.TypeWxpay, vals[SettingPaymentVisibleMethodWxpaySource]) == VisibleMethodSourceOfficialWechat
|
||||
}
|
||||
|
||||
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
|
||||
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
|
||||
if err != nil {
|
||||
@@ -239,16 +281,18 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
|
||||
}
|
||||
resumeToken := ""
|
||||
if resume := s.paymentResume(); resume != nil {
|
||||
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
|
||||
OrderID: order.ID,
|
||||
UserID: order.UserID,
|
||||
ProviderInstanceID: sel.InstanceID,
|
||||
ProviderKey: sel.ProviderKey,
|
||||
PaymentType: req.PaymentType,
|
||||
CanonicalReturnURL: canonicalReturnURL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create payment resume token: %w", err)
|
||||
if resume.isSigningConfigured() {
|
||||
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
|
||||
OrderID: order.ID,
|
||||
UserID: order.UserID,
|
||||
ProviderInstanceID: sel.InstanceID,
|
||||
ProviderKey: sel.ProviderKey,
|
||||
PaymentType: req.PaymentType,
|
||||
CanonicalReturnURL: canonicalReturnURL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create payment resume token: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)
|
||||
|
||||
112
backend/internal/service/payment_order_jsapi_test.go
Normal file
112
backend/internal/service/payment_order_jsapi_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
const jsapiTestEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||
|
||||
func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
|
||||
compatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
|
||||
"appId": "wx-merchant-app",
|
||||
"mpAppId": "wx-mp-app",
|
||||
"mchId": "mch-compatible",
|
||||
"privateKey": "private-key",
|
||||
"apiV3Key": jsapiTestEncryptionKey,
|
||||
"publicKey": "public-key",
|
||||
"publicKeyId": "key-compatible",
|
||||
"certSerial": "serial-compatible",
|
||||
})
|
||||
incompatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
|
||||
"appId": "wx-merchant-other",
|
||||
"mpAppId": "wx-mp-other",
|
||||
"mchId": "mch-incompatible",
|
||||
"privateKey": "private-key",
|
||||
"apiV3Key": jsapiTestEncryptionKey,
|
||||
"publicKey": "public-key",
|
||||
"publicKeyId": "key-incompatible",
|
||||
"certSerial": "serial-incompatible",
|
||||
})
|
||||
|
||||
compatible, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-compatible").
|
||||
SetConfig(compatibleConfig).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
SetSortOrder(1).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create compatible wxpay instance: %v", err)
|
||||
}
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-incompatible").
|
||||
SetConfig(incompatibleConfig).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
SetSortOrder(2).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create incompatible wxpay instance: %v", err)
|
||||
}
|
||||
|
||||
configService := &PaymentConfigService{
|
||||
entClient: client,
|
||||
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "true",
|
||||
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
|
||||
}},
|
||||
encryptionKey: []byte(jsapiTestEncryptionKey),
|
||||
}
|
||||
loadBalancer := newVisibleMethodLoadBalancer(
|
||||
payment.NewDefaultLoadBalancer(client, []byte(jsapiTestEncryptionKey)),
|
||||
configService,
|
||||
)
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
loadBalancer: loadBalancer,
|
||||
configService: configService,
|
||||
}
|
||||
|
||||
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
|
||||
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
|
||||
|
||||
sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{
|
||||
PaymentType: payment.TypeWxpay,
|
||||
OpenID: "openid-123",
|
||||
IsWeChatBrowser: true,
|
||||
}, &PaymentConfig{LoadBalanceStrategy: string(payment.StrategyRoundRobin)}, 12.5)
|
||||
if err != nil {
|
||||
t.Fatalf("selectCreateOrderInstance returned error: %v", err)
|
||||
}
|
||||
if sel == nil {
|
||||
t.Fatal("expected selected instance, got nil")
|
||||
}
|
||||
expectedInstanceID := fmt.Sprintf("%d", compatible.ID)
|
||||
if sel.InstanceID != expectedInstanceID {
|
||||
t.Fatalf("selected instance id = %q, want %q", sel.InstanceID, expectedInstanceID)
|
||||
}
|
||||
}
|
||||
|
||||
func mustEncryptJSAPITestConfig(t *testing.T, config map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
data, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal config: %v", err)
|
||||
}
|
||||
encrypted, err := payment.Encrypt(string(data), []byte(jsapiTestEncryptionKey))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt config: %v", err)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
@@ -33,6 +33,9 @@ const (
|
||||
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
|
||||
|
||||
wechatPaymentResumeTokenType = "wechat_payment_resume"
|
||||
|
||||
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
|
||||
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
|
||||
)
|
||||
|
||||
type ResumeTokenClaims struct {
|
||||
@@ -70,6 +73,17 @@ func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
|
||||
return &PaymentResumeService{signingKey: signingKey}
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) isSigningConfigured() bool {
|
||||
return s != nil && len(s.signingKey) > 0
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) ensureSigningKey() error {
|
||||
if s.isSigningConfigured() {
|
||||
return nil
|
||||
}
|
||||
return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
|
||||
}
|
||||
|
||||
func NormalizeVisibleMethod(method string) string {
|
||||
return payment.GetBasePaymentType(strings.TrimSpace(method))
|
||||
}
|
||||
@@ -240,6 +254,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
|
||||
if err := s.ensureSigningKey(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if claims.OrderID <= 0 {
|
||||
return "", fmt.Errorf("resume token requires order id")
|
||||
}
|
||||
@@ -250,6 +267,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
|
||||
if err := s.ensureSigningKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var claims ResumeTokenClaims
|
||||
if err := s.parseSignedToken(token, &claims); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
|
||||
@@ -261,6 +281,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
|
||||
if err := s.ensureSigningKey(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.OpenID = strings.TrimSpace(claims.OpenID)
|
||||
if claims.OpenID == "" {
|
||||
return "", fmt.Errorf("wechat payment resume token requires openid")
|
||||
@@ -282,6 +305,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
|
||||
if err := s.ensureSigningKey(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var claims WeChatPaymentResumeClaims
|
||||
if err := s.parseSignedToken(token, &claims); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
|
||||
@@ -330,11 +356,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) sign(payload string) string {
|
||||
key := s.signingKey
|
||||
if len(key) == 0 {
|
||||
key = []byte(paymentResumeFallbackSigningKey)
|
||||
}
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac := hmac.New(sha256.New, s.signingKey)
|
||||
_, _ = mac.Write([]byte(payload))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
@@ -150,6 +154,27 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewPaymentResumeService(nil)
|
||||
_, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
|
||||
if err == nil {
|
||||
t.Fatal("CreateToken should reject missing signing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
|
||||
svc := NewPaymentResumeService(nil)
|
||||
_, err := svc.ParseToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("ParseToken should reject tokens when signing key is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -183,6 +208,31 @@ func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewPaymentResumeService(nil)
|
||||
_, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
|
||||
if err == nil {
|
||||
t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
|
||||
TokenType: wechatPaymentResumeTokenType,
|
||||
OpenID: "openid-123",
|
||||
PaymentType: payment.TypeWxpay,
|
||||
})
|
||||
svc := NewPaymentResumeService(nil)
|
||||
_, err := svc.ParseWeChatPaymentResumeToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -315,3 +365,17 @@ func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey stri
|
||||
c.lastPaymentType = paymentType
|
||||
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
|
||||
}
|
||||
|
||||
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey))
|
||||
_, _ = mac.Write([]byte(encodedPayload))
|
||||
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
return encodedPayload + "." + signature
|
||||
}
|
||||
|
||||
@@ -16,33 +16,70 @@ import (
|
||||
// It resolves the original provider instance from the order whenever possible and
|
||||
// only falls back to a registry provider for legacy/single-instance scenarios.
|
||||
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
|
||||
providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil, payment.ErrProviderNotFound
|
||||
}
|
||||
return providers[0], nil
|
||||
}
|
||||
|
||||
// GetWebhookProviders returns provider candidates that can verify the webhook.
|
||||
// Official WeChat Pay may require multiple candidates because the callback body
|
||||
// cannot be bound to a merchant before decryption.
|
||||
func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
|
||||
if outTradeNo != "" {
|
||||
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
|
||||
if err == nil {
|
||||
if psHasPinnedProviderInstance(order) {
|
||||
return s.getPinnedOrderProvider(ctx, order)
|
||||
prov, err := s.getPinnedOrderProvider(ctx, order)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []payment.Provider{prov}, nil
|
||||
}
|
||||
inst, err := s.getOrderProviderInstance(ctx, order)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load order provider instance: %w", err)
|
||||
}
|
||||
if inst != nil {
|
||||
return s.createProviderFromInstance(ctx, inst)
|
||||
prov, err := s.createProviderFromInstance(ctx, inst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []payment.Provider{prov}, nil
|
||||
}
|
||||
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
|
||||
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
|
||||
}
|
||||
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
|
||||
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
|
||||
}
|
||||
s.EnsureProviders(ctx)
|
||||
return s.registry.GetProviderByKey(providerKey)
|
||||
prov, err := s.registry.GetProviderByKey(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []payment.Provider{prov}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(providerKey) == payment.TypeWxpay {
|
||||
return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
|
||||
}
|
||||
|
||||
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
|
||||
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
|
||||
}
|
||||
|
||||
s.EnsureProviders(ctx)
|
||||
return s.registry.GetProviderByKey(providerKey)
|
||||
prov, err := s.registry.GetProviderByKey(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []payment.Provider{prov}, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
|
||||
@@ -78,3 +115,34 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
|
||||
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
|
||||
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
|
||||
}
|
||||
|
||||
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(
|
||||
paymentproviderinstance.ProviderKeyEQ(providerKey),
|
||||
paymentproviderinstance.EnabledEQ(true),
|
||||
).
|
||||
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query webhook provider instances: %w", err)
|
||||
}
|
||||
if len(instances) == 0 {
|
||||
return nil, payment.ErrProviderNotFound
|
||||
}
|
||||
|
||||
providers := make([]payment.Provider, 0, len(instances))
|
||||
for _, inst := range instances {
|
||||
prov, provErr := s.createProviderFromInstance(ctx, inst)
|
||||
if provErr != nil {
|
||||
slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
|
||||
continue
|
||||
}
|
||||
providers = append(providers, prov)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil, payment.ErrProviderNotFound
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
@@ -208,10 +208,28 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
|
||||
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
|
||||
"appId": "wx-app-a",
|
||||
"mchId": "mch-a",
|
||||
"privateKey": "private-key-a",
|
||||
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||
"publicKey": "public-key-a",
|
||||
"publicKeyId": "public-key-id-a",
|
||||
"certSerial": "cert-serial-a",
|
||||
})
|
||||
wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
|
||||
"appId": "wx-app-b",
|
||||
"mchId": "mch-b",
|
||||
"privateKey": "private-key-b",
|
||||
"apiV3Key": webhookProviderTestEncryptionKey,
|
||||
"publicKey": "public-key-b",
|
||||
"publicKeyId": "public-key-id-b",
|
||||
"certSerial": "cert-serial-b",
|
||||
})
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-a").
|
||||
SetConfig("{}").
|
||||
SetConfig(wxpayConfigA).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
@@ -219,19 +237,51 @@ func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeWxpay).
|
||||
SetName("wxpay-b").
|
||||
SetConfig("{}").
|
||||
SetConfig(wxpayConfigB).
|
||||
SetSupportedTypes("wxpay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
loadBalancer: newWebhookProviderTestLoadBalancer(client),
|
||||
registry: payment.NewRegistry(),
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 2)
|
||||
}
|
||||
|
||||
func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeAlipay).
|
||||
SetName("alipay-a").
|
||||
SetConfig("{}").
|
||||
SetSupportedTypes("alipay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeAlipay).
|
||||
SetName("alipay-b").
|
||||
SetConfig("{}").
|
||||
SetSupportedTypes("alipay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
registry: payment.NewRegistry(),
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "")
|
||||
_, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ambiguous")
|
||||
}
|
||||
@@ -260,8 +310,10 @@ func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "")
|
||||
providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
prov := providers[0]
|
||||
require.Equal(t, payment.TypeStripe, prov.ProviderKey())
|
||||
}
|
||||
|
||||
@@ -308,7 +360,7 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
_, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
|
||||
_, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "provider instance")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user