fix(payment): support source routing and compatible resume signing

This commit is contained in:
IanShaw027
2026-04-22 12:30:17 +08:00
parent b2e0712190
commit d6a04bb772
12 changed files with 570 additions and 136 deletions

View File

@@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
}
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
require.NoError(t, err)
@@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.Contains(t, resp.Data, "expires_at")
require.Contains(t, resp.Data, "refund_amount")
}
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-mismatch-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
SetOutTradeNo("resolve-order-mismatch-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID + 999,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":" "}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}

View File

@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 {
return typeInstances
}
@@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 {
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil || providerKey == "" {
delete(filtered, method)
continue
}
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
if len(selectedInstances) == 0 {
delete(filtered, method)
continue
}
filtered[method] = selectedInstances
}
return filtered
}

View File

@@ -301,65 +301,104 @@ func TestPcInstanceTypeLimits(t *testing.T) {
})
}
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
tests := []struct {
name string
sourceSetting string
wantAlipaySingleMin float64
wantAlipaySingleMax float64
wantGlobalMin float64
wantGlobalMax float64
}{
{
name: "official source",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantAlipaySingleMin: 10,
wantAlipaySingleMax: 100,
wantGlobalMin: 10,
wantGlobalMax: 300,
},
{
name: "easypay source",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantAlipaySingleMin: 20,
wantAlipaySingleMax: 200,
wantGlobalMin: 20,
wantGlobalMax: 300,
},
}
svc := &PaymentConfigService{
entClient: client,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
},
},
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
if !ok {
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
}
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
}
})
}
}

View File

@@ -4,9 +4,12 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) {
}
}
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
t.Parallel()
ctx := context.Background()
@@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay",
Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"},
Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
SupportedTypes: []string{"alipay"},
Enabled: true,
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
t.Parallel()
ctx := context.Background()
@@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay",
Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"},
Config: validWxpayProviderConfig(t),
SupportedTypes: []string{"wxpay"},
Enabled: false,
})
@@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true),
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
@@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
func boolPtrValue(v bool) *bool {
return &v
}
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return map[string]string{
"appId": "wx-app-test",
"mchId": "mch-test",
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": "12345678901234567890123456789012",
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-test",
"certSerial": "cert-serial-test",
}
}

View File

@@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return o, nil
}
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
outTradeNo := strings.TrimSpace(raw)
if outTradeNo == "" {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
}
if len(outTradeNo) > 64 {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
for _, ch := range outTradeNo {
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"strings"
"testing"
"time"
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
@@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if resp == nil {
t.Fatal("expected oauth-required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",

View File

@@ -6,6 +6,7 @@ import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return nil, fmt.Errorf("get order by resume token: %w", err)
}
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
return nil, invalidResumeTokenMatchError()
}
snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
return nil, invalidResumeTokenMatchError()
}
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
return nil, invalidResumeTokenMatchError()
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
return nil, invalidResumeTokenMatchError()
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order)
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil
}
func invalidResumeTokenMatchError() error {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
}
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
svc := &PaymentService{
entClient: newPaymentConfigServiceTestClient(t),
}
_, err := svc.VerifyOrderPublic(context.Background(), " ")
require.Error(t, err)
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
}

View File

@@ -1,6 +1,7 @@
package service
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type PaymentResumeService struct {
signingKey []byte
verifyKeys [][]byte
}
type visibleMethodLoadBalancer struct {
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
svc := &PaymentResumeService{}
if len(signingKey) > 0 {
svc.signingKey = append([]byte(nil), signingKey...)
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
}
for _, fallback := range verifyFallbacks {
if len(fallback) == 0 {
continue
}
cloned := append([]byte(nil), fallback...)
duplicate := false
for _, existing := range svc.verifyKeys {
if bytes.Equal(existing, cloned) {
duplicate = true
break
}
}
if !duplicate {
svc.verifyKeys = append(svc.verifyKeys, cloned)
}
}
return svc
}
func (s *PaymentResumeService) isSigningConfigured() bool {
@@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
if !s.verifySignature(parts[0], parts[1]) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest)
}
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
if s == nil {
return false
}
for _, key := range s.verifyKeys {
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
return true
}
}
return false
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 {
return nil
@@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
}
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
return signPaymentResumePayload(payload, s.signingKey)
}
func signPaymentResumePayload(payload string, key []byte) string {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: legacyKey,
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
@@ -424,14 +477,14 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
t.Parallel()
tests := []struct {
name string
method payment.PaymentType
officialName string
officialTypes string
easyPayName string
easyPayTypes string
sourceSetting string
wantProvider string
name string
method payment.PaymentType
officialName string
officialTypes string
easyPayName string
easyPayTypes string
sourceSetting string
wantProvider string
}{
{
name: "alipay uses official source",
@@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
officialProviderKey = payment.TypeWxpay
}
_, err = client.PaymentProviderInstance.Create().
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey).
SetName(tt.officialName).
SetConfig("{}").

View File

@@ -1,10 +1,14 @@
package service
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"log/slog"
"math/rand/v2"
"os"
"strings"
"sync"
"time"
@@ -44,6 +48,8 @@ const (
orderIDPrefix = "sub2_"
)
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
@@ -179,7 +185,7 @@ type PaymentService struct {
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}
@@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
return psNewPaymentResumeService(s.configService)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
signingKey, _ := psResumeSigningKeys(configService)
return signingKey
}
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
legacyKey := psResumeLegacyVerificationKey(configService)
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
}
return legacyKey, nil
}
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
return signingKey, nil
}
return signingKey, [][]byte{legacyKey}
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if len(raw) >= 64 && len(raw)%2 == 0 {
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
return decoded
}
}
return []byte(raw)
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {

View File

@@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return filtered
}
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if !providerSupportsVisibleMethod(inst, method) {
continue
}
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
continue
}
filtered = append(filtered, inst)
}
return filtered
}
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
seen := make(map[string]struct{}, len(instances))
keys := make([]string, 0, len(instances))
for _, inst := range instances {
if inst == nil {
continue
}
key := strings.TrimSpace(inst.ProviderKey)
if key == "" {
continue
}
normalized := strings.ToLower(key)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, key)
}
return keys
}
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
@@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes string,
enabled bool,
) error {
if s == nil || s.entClient == nil || !enabled {
return nil
}
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
}
}
// Visible methods are selected by configured source (official/easypay),
// so multiple enabled providers can intentionally claim the same user-facing
// method. Order creation and limits will route through the configured source.
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
return nil
}
@@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context
return providerKey, nil
}
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
ctx context.Context,
method string,
matching []*dbent.PaymentProviderInstance,
) (string, error) {
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
case 0:
return "", nil
case 1:
return strings.TrimSpace(providerKeys[0]), nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return "", err
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return strings.TrimSpace(selected.ProviderKey), nil
}
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
ctx context.Context,
method string,
@@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
}
matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) {
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return nil, err
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return nil, infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return selected, nil
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil {
return nil, err
}
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
}