fix(payment): support source routing and compatible resume signing

This commit is contained in:
IanShaw027
2026-04-22 12:30:17 +08:00
parent b2e0712190
commit d6a04bb772
12 changed files with 570 additions and 136 deletions

View File

@@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
} }
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
require.NoError(t, err) require.NoError(t, err)
@@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.Contains(t, resp.Data, "expires_at") require.Contains(t, resp.Data, "expires_at")
require.Contains(t, resp.Data, "refund_amount") require.Contains(t, resp.Data, "refund_amount")
} }
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-mismatch-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
SetOutTradeNo("resolve-order-mismatch-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID + 999,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":" "}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}

View File

@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err) return nil, fmt.Errorf("query provider instances: %w", err)
} }
typeInstances := pcGroupByPaymentType(instances) typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances) typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{ resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)), Methods: make(map[string]MethodLimits, len(typeInstances)),
} }
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil return resp, nil
} }
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 { if len(typeInstances) == 0 {
return typeInstances return typeInstances
} }
@@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method) matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 { providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil || providerKey == "" {
delete(filtered, method) delete(filtered, method)
continue continue
} }
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]} selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
if len(selectedInstances) == 0 {
delete(filtered, method)
continue
}
filtered[method] = selectedInstances
} }
return filtered return filtered
} }

View File

@@ -301,65 +301,104 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}) })
} }
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) { func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
ctx := context.Background() tests := []struct {
client := newPaymentConfigServiceTestClient(t) name string
sourceSetting string
_, err := client.PaymentProviderInstance.Create(). wantAlipaySingleMin float64
SetProviderKey(payment.TypeAlipay). wantAlipaySingleMax float64
SetName("Official Alipay"). wantGlobalMin float64
SetConfig("{}"). wantGlobalMax float64
SetSupportedTypes("alipay"). }{
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). {
SetEnabled(true). name: "official source",
Save(ctx) sourceSetting: VisibleMethodSourceOfficialAlipay,
if err != nil { wantAlipaySingleMin: 10,
t.Fatalf("create official alipay instance: %v", err) wantAlipaySingleMax: 100,
} wantGlobalMin: 10,
_, err = client.PaymentProviderInstance.Create(). wantGlobalMax: 300,
SetProviderKey(payment.TypeEasyPay). },
SetName("EasyPay Alipay"). {
SetConfig("{}"). name: "easypay source",
SetSupportedTypes("alipay"). sourceSetting: VisibleMethodSourceEasyPayAlipay,
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). wantAlipaySingleMin: 20,
SetEnabled(true). wantAlipaySingleMax: 200,
Save(ctx) wantGlobalMin: 20,
if err != nil { wantGlobalMax: 300,
t.Fatalf("create easypay alipay instance: %v", err) },
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
} }
svc := &PaymentConfigService{ for _, tt := range tests {
entClient: client, t.Run(tt.name, func(t *testing.T) {
} ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
resp, err := svc.GetAvailableMethodLimits(ctx) _, err := client.PaymentProviderInstance.Create().
if err != nil { SetProviderKey(payment.TypeAlipay).
t.Fatalf("GetAvailableMethodLimits returned error: %v", err) SetName("Official Alipay").
} SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
if _, ok := resp.Methods[payment.TypeAlipay]; ok { svc := &PaymentConfigService{
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay]) entClient: client,
} settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
},
},
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay] resp, err := svc.GetAvailableMethodLimits(ctx)
if !ok { if err != nil {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
} }
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) alipayLimits, ok := resp.Methods[payment.TypeAlipay]
} if !ok {
if resp.GlobalMin != 30 || resp.GlobalMax != 300 { t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax) }
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
}
})
} }
} }

View File

@@ -4,9 +4,12 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) {
} }
} }
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) { func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
@@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay", ProviderKey: "alipay",
Name: "Official Alipay", Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"}, Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
SupportedTypes: []string{"alipay"}, SupportedTypes: []string{"alipay"},
Enabled: true, Enabled: true,
}) })
require.Error(t, err) require.NoError(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
} }
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) { func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
@@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay", ProviderKey: "wxpay",
Name: "Official WeChat", Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"}, Config: validWxpayProviderConfig(t),
SupportedTypes: []string{"wxpay"}, SupportedTypes: []string{"wxpay"},
Enabled: false, Enabled: false,
}) })
@@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true), Enabled: boolPtrValue(true),
}) })
require.Error(t, err) require.NoError(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
} }
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
@@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
func boolPtrValue(v bool) *bool { func boolPtrValue(v bool) *bool {
return &v return &v
} }
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return map[string]string{
"appId": "wx-app-test",
"mchId": "mch-test",
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": "12345678901234567890123456789012",
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-test",
"certSerial": "cert-serial-test",
}
}

View File

@@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where // if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode). // the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query(). o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)). Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx) Only(ctx)
@@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the // triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state. // only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query(). o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)). Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx) Only(ctx)
@@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return o, nil return o, nil
} }
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
outTradeNo := strings.TrimSpace(raw)
if outTradeNo == "" {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
}
if len(outTradeNo) > 64 {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
for _, ch := range outTradeNo {
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now() now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"strings"
"testing" "testing"
"time" "time"
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
} }
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
svc := newWeChatPaymentOAuthTestService(map[string]string{ svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456", SettingKeyWeChatConnectAppID: "wx123456",
@@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin
} }
} }
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if resp == nil {
t.Fatal("expected oauth-required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{ svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectEnabled: "true",

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil { if err != nil {
if dbent.IsNotFound(err) {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return nil, fmt.Errorf("get order by resume token: %w", err) return nil, fmt.Errorf("get order by resume token: %w", err)
} }
if claims.UserID > 0 && order.UserID != claims.UserID { if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch") return nil, invalidResumeTokenMatchError()
} }
snapshot := psOrderProviderSnapshot(order) snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
} }
} }
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch") return nil, invalidResumeTokenMatchError()
} }
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey { if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
return nil, fmt.Errorf("resume token provider key mismatch") return nil, invalidResumeTokenMatchError()
} }
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
return nil, fmt.Errorf("resume token payment type mismatch") return nil, invalidResumeTokenMatchError()
} }
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order) result := s.checkPaid(ctx, order)
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil return order, nil
} }
func invalidResumeTokenMatchError() error {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
} }

View File

@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_, err = svc.GetPublicOrderByResumeToken(ctx, token) _, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "resume token") require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
} }
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require.Equal(t, order.ID, got.ID) require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount) require.Equal(t, 0, provider.queryCount)
} }
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
svc := &PaymentService{
entClient: newPaymentConfigServiceTestClient(t),
}
_, err := svc.VerifyOrderPublic(context.Background(), " ")
require.Error(t, err)
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
}

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"bytes"
"context" "context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type PaymentResumeService struct { type PaymentResumeService struct {
signingKey []byte signingKey []byte
verifyKeys [][]byte
} }
type visibleMethodLoadBalancer struct { type visibleMethodLoadBalancer struct {
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService *PaymentConfigService configService *PaymentConfigService
} }
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey} svc := &PaymentResumeService{}
if len(signingKey) > 0 {
svc.signingKey = append([]byte(nil), signingKey...)
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
}
for _, fallback := range verifyFallbacks {
if len(fallback) == 0 {
continue
}
cloned := append([]byte(nil), fallback...)
duplicate := false
for _, existing := range svc.verifyKeys {
if bytes.Equal(existing, cloned) {
duplicate = true
break
}
}
if !duplicate {
svc.verifyKeys = append(svc.verifyKeys, cloned)
}
}
return svc
} }
func (s *PaymentResumeService) isSigningConfigured() bool { func (s *PaymentResumeService) isSigningConfigured() bool {
@@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if len(parts) != 2 || parts[0] == "" || parts[1] == "" { if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
} }
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { if !s.verifySignature(parts[0], parts[1]) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
} }
payload, err := base64.RawURLEncoding.DecodeString(parts[0]) payload, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest) return json.Unmarshal(payload, dest)
} }
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
if s == nil {
return false
}
for _, key := range s.verifyKeys {
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
return true
}
}
return false
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 { if expiresAt <= 0 {
return nil return nil
@@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
} }
func (s *PaymentResumeService) sign(payload string) string { func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey) return signPaymentResumePayload(payload, s.signingKey)
}
func signPaymentResumePayload(payload string, key []byte) string {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload)) _, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
} }

View File

@@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
} }
} }
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: legacyKey,
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) { func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel() t.Parallel()
@@ -424,14 +477,14 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
method payment.PaymentType method payment.PaymentType
officialName string officialName string
officialTypes string officialTypes string
easyPayName string easyPayName string
easyPayTypes string easyPayTypes string
sourceSetting string sourceSetting string
wantProvider string wantProvider string
}{ }{
{ {
name: "alipay uses official source", name: "alipay uses official source",
@@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
officialProviderKey = payment.TypeWxpay officialProviderKey = payment.TypeWxpay
} }
_, err = client.PaymentProviderInstance.Create(). _, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey). SetProviderKey(officialProviderKey).
SetName(tt.officialName). SetName(tt.officialName).
SetConfig("{}"). SetConfig("{}").

View File

@@ -1,10 +1,14 @@
package service package service
import ( import (
"bytes"
"context" "context"
"encoding/hex"
"fmt" "fmt"
"log/slog" "log/slog"
"math/rand/v2" "math/rand/v2"
"os"
"strings"
"sync" "sync"
"time" "time"
@@ -44,6 +48,8 @@ const (
orderIDPrefix = "sub2_" orderIDPrefix = "sub2_"
) )
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
// --- Types --- // --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers. // generateOutTradeNo creates a unique external order ID for payment providers.
@@ -179,7 +185,7 @@ type PaymentService struct {
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) svc.resumeService = psNewPaymentResumeService(configService)
return svc return svc
} }
@@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil { if s.resumeService != nil {
return s.resumeService return s.resumeService
} }
return NewPaymentResumeService(psResumeSigningKey(s.configService)) return psNewPaymentResumeService(s.configService)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
} }
func psResumeSigningKey(configService *PaymentConfigService) []byte { func psResumeSigningKey(configService *PaymentConfigService) []byte {
signingKey, _ := psResumeSigningKeys(configService)
return signingKey
}
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
legacyKey := psResumeLegacyVerificationKey(configService)
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
}
return legacyKey, nil
}
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
return signingKey, nil
}
return signingKey, [][]byte{legacyKey}
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil { if configService == nil {
return nil return nil
} }
return configService.encryptionKey return configService.encryptionKey
} }
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if len(raw) >= 64 && len(raw)%2 == 0 {
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
return decoded
}
}
return []byte(raw)
}
func psSliceContains(sl []string, s string) bool { func psSliceContains(sl []string, s string) bool {
for _, v := range sl { for _, v := range sl {
if v == s { if v == s {

View File

@@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return filtered return filtered
} }
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if !providerSupportsVisibleMethod(inst, method) {
continue
}
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
continue
}
filtered = append(filtered, inst)
}
return filtered
}
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
seen := make(map[string]struct{}, len(instances))
keys := make([]string, 0, len(instances))
for _, inst := range instances {
if inst == nil {
continue
}
key := strings.TrimSpace(inst.ProviderKey)
if key == "" {
continue
}
normalized := strings.ToLower(key)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, key)
}
return keys
}
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
providerKey = strings.TrimSpace(providerKey) providerKey = strings.TrimSpace(providerKey)
if providerKey == "" { if providerKey == "" {
@@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes string, supportedTypes string,
enabled bool, enabled bool,
) error { ) error {
if s == nil || s.entClient == nil || !enabled { // Visible methods are selected by configured source (official/easypay),
return nil // so multiple enabled providers can intentionally claim the same user-facing
} // method. Order creation and limits will route through the configured source.
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
}
}
return nil return nil
} }
@@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context
return providerKey, nil return providerKey, nil
} }
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
ctx context.Context,
method string,
matching []*dbent.PaymentProviderInstance,
) (string, error) {
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
case 0:
return "", nil
case 1:
return strings.TrimSpace(providerKeys[0]), nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return "", err
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return strings.TrimSpace(selected.ProviderKey), nil
}
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
ctx context.Context, ctx context.Context,
method string, method string,
@@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
} }
matching := filterEnabledVisibleMethodInstances(instances, method) matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) { providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
case 0: if err != nil {
return nil, nil return nil, err
case 1:
return matching[0], nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return nil, err
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return nil, infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return selected, nil
} }
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
} }