fix(payment): restore upgrade-safe payment flows

This commit is contained in:
IanShaw027
2026-04-22 14:57:16 +08:00
parent 36aed35957
commit 1aab084ecb
14 changed files with 645 additions and 68 deletions

View File

@@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
// changed while the instance has in-progress orders. This includes secrets plus
// all provider identity fields that are snapshotted into orders or used by
// webhook/refund verification.
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
fields, ok := providerSensitiveConfigFields[providerKey]
if !ok {
@@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
return found
}
func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool {
fields, ok := providerPendingOrderProtectedConfigFields[providerKey]
if !ok {
return false
}
for fieldName := range fields {
if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) {
return true
}
}
return false
}
func providerConfigFieldValue(config map[string]string, fieldName string) string {
for key, value := range config {
if strings.EqualFold(key, fieldName) {
return value
}
}
return ""
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
@@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
var pendingOrderCount *int
getPendingOrderCount := func() (int, error) {
if pendingOrderCount != nil {
return *pendingOrderCount, nil
}
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return 0, fmt.Errorf("check pending orders: %w", err)
}
pendingOrderCount = &count
return count, nil
}
nextEnabled := current.Enabled
if req.Enabled != nil {
nextEnabled = *req.Enabled
@@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
return nil, err
}
var mergedConfig map[string]string
if req.Config != nil {
hasSensitive := false
for k, v := range req.Config {
if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
hasSensitive = true
break
}
currentConfig, err := s.decryptConfig(current.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config: %w", err)
}
if hasSensitive {
count, err := s.countPendingOrders(ctx, id)
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) {
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if req.Enabled != nil && !*req.Enabled {
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.Enabled != nil {
finalEnabled = *req.Enabled
}
var mergedConfig map[string]string
if req.Config != nil {
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
}
if finalEnabled {
configToValidate := mergedConfig
if configToValidate == nil {
@@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
// Load current instance to compare types

View File

@@ -8,8 +8,13 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"strconv"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -315,10 +320,263 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
}
func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay appId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"appId": "wx-app-updated"},
fieldName: "appId",
wantValue: "wx-app-test",
},
{
name: "wxpay mpAppId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfigWithJSAPIAppID,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"},
fieldName: "mpAppId",
wantValue: "wx-mp-app-test",
},
{
name: "wxpay mchId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mchId": "mch-updated"},
fieldName: "mchId",
wantValue: "mch-test",
},
{
name: "wxpay publicKeyId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"},
fieldName: "publicKeyId",
wantValue: "public-key-id-test",
},
{
name: "wxpay certSerial",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"certSerial": "cert-serial-updated"},
fieldName: "certSerial",
wantValue: "cert-serial-test",
},
{
name: "alipay appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-updated"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
{
name: "easypay pid",
providerKey: payment.TypeEasyPay,
createConfig: validEasyPayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"pid": "pid-updated"},
fieldName: "pid",
wantValue: "pid-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "protected-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.Nil(t, updated)
require.Error(t, err)
require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err))
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay notifyUrl",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"},
fieldName: "notifyUrl",
wantValue: "https://merchant.example.com/wxpay/notify-v2",
},
{
name: "alipay same appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-test"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "safe-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.NoError(t, err)
require.NotNil(t, updated)
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
t.Helper()
user, err := client.User.Create().
SetEmail("provider-config-pending@example.com").
SetPasswordHash("hash").
SetUsername("provider-config-pending-user").
Save(ctx)
require.NoError(t, err)
instanceID := strconv.FormatInt(instance.ID, 10)
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID).
SetOutTradeNo("sub2_pending_provider_config_" + instanceID).
SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instanceID).
SetProviderKey(instance.ProviderKey).
Save(ctx)
require.NoError(t, err)
}
func providerPendingOrderPaymentType(providerKey string) string {
switch providerKey {
case payment.TypeWxpay:
return payment.TypeWxpay
case payment.TypeAlipay:
return payment.TypeAlipay
default:
return payment.TypeAlipay
}
}
func boolPtrValue(v bool) *bool {
return &v
}
func validAlipayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"appId": "alipay-app-test",
"privateKey": "alipay-private-key-test",
"notifyUrl": "https://merchant.example.com/alipay/notify",
"returnUrl": "https://merchant.example.com/alipay/return",
}
}
func validEasyPayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"pid": "pid-test",
"pkey": "pkey-test",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/easypay/notify",
"returnUrl": "https://merchant.example.com/easypay/return",
}
}
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
@@ -340,3 +598,11 @@ func validWxpayProviderConfig(t *testing.T) map[string]string {
"certSerial": "cert-serial-test",
}
}
func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string {
t.Helper()
cfg := validWxpayProviderConfig(t)
cfg["mpAppId"] = "wx-mp-app-test"
return cfg
}

View File

@@ -387,6 +387,45 @@ func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDu
}
}
func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
svc := newLegacyAwarePaymentResumeService(legacyKey)
explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if explicitClaims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key")
}
legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if legacyClaims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()

View File

@@ -268,8 +268,16 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
return psNewPaymentResumeService(s.configService)
}
func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
return newLegacyAwarePaymentResumeService(legacyKey)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService))
}
func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
}
@@ -279,8 +287,18 @@ func psResumeSigningKey(configService *PaymentConfigService) []byte {
}
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
return resolvePaymentResumeSigningKeys(psResumeLegacyVerificationKey(configService))
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
legacyKey := psResumeLegacyVerificationKey(configService)
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
@@ -293,13 +311,6 @@ func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte)
return signingKey, [][]byte{legacyKey}
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {