fix(payment): support source routing and compatible resume signing
This commit is contained in:
@@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||
|
||||
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
@@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
||||
require.Contains(t, resp.Data, "expires_at")
|
||||
require.Contains(t, resp.Data, "refund_amount")
|
||||
}
|
||||
|
||||
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||
|
||||
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("public-resolve-mismatch@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetUsername("public-resolve-mismatch-user").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
order, err := client.PaymentOrder.Create().
|
||||
SetUserID(user.ID).
|
||||
SetUserEmail(user.Email).
|
||||
SetUserName(user.Username).
|
||||
SetAmount(100).
|
||||
SetPayAmount(103).
|
||||
SetFeeRate(0.03).
|
||||
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
|
||||
SetOutTradeNo("resolve-order-mismatch-no").
|
||||
SetPaymentType(payment.TypeAlipay).
|
||||
SetPaymentTradeNo("trade-public-resolve-mismatch").
|
||||
SetOrderType(payment.OrderTypeBalance).
|
||||
SetStatus(service.OrderStatusPaid).
|
||||
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||
SetPaidAt(time.Now()).
|
||||
SetClientIP("127.0.0.1").
|
||||
SetSrcHost("api.example.com").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
|
||||
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
|
||||
OrderID: order.ID,
|
||||
UserID: user.ID + 999,
|
||||
PaymentType: payment.TypeAlipay,
|
||||
CanonicalReturnURL: "https://app.example.com/payment/result",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/payment/public/orders/resolve",
|
||||
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
|
||||
)
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.ResolveOrderPublicByResumeToken(ctx)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusBadRequest, resp.Code)
|
||||
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
|
||||
}
|
||||
|
||||
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/payment/public/orders/verify",
|
||||
bytes.NewBufferString(`{"out_trade_no":" "}`),
|
||||
)
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.VerifyOrderPublic(ctx)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusBadRequest, resp.Code)
|
||||
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
||||
return nil, fmt.Errorf("query provider instances: %w", err)
|
||||
}
|
||||
typeInstances := pcGroupByPaymentType(instances)
|
||||
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
|
||||
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
|
||||
resp := &MethodLimitsResponse{
|
||||
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
|
||||
func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
|
||||
if len(typeInstances) == 0 {
|
||||
return typeInstances
|
||||
}
|
||||
@@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
|
||||
|
||||
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
|
||||
matching := filterEnabledVisibleMethodInstances(instances, method)
|
||||
if len(matching) != 1 {
|
||||
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
|
||||
if err != nil || providerKey == "" {
|
||||
delete(filtered, method)
|
||||
continue
|
||||
}
|
||||
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
|
||||
selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
|
||||
if len(selectedInstances) == 0 {
|
||||
delete(filtered, method)
|
||||
continue
|
||||
}
|
||||
filtered[method] = selectedInstances
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -301,7 +301,35 @@ func TestPcInstanceTypeLimits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
|
||||
func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceSetting string
|
||||
wantAlipaySingleMin float64
|
||||
wantAlipaySingleMax float64
|
||||
wantGlobalMin float64
|
||||
wantGlobalMax float64
|
||||
}{
|
||||
{
|
||||
name: "official source",
|
||||
sourceSetting: VisibleMethodSourceOfficialAlipay,
|
||||
wantAlipaySingleMin: 10,
|
||||
wantAlipaySingleMax: 100,
|
||||
wantGlobalMin: 10,
|
||||
wantGlobalMax: 300,
|
||||
},
|
||||
{
|
||||
name: "easypay source",
|
||||
sourceSetting: VisibleMethodSourceEasyPayAlipay,
|
||||
wantAlipaySingleMin: 20,
|
||||
wantAlipaySingleMax: 200,
|
||||
wantGlobalMin: 20,
|
||||
wantGlobalMax: 300,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
|
||||
@@ -341,6 +369,11 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
|
||||
|
||||
svc := &PaymentConfigService{
|
||||
entClient: client,
|
||||
settingRepo: &paymentConfigSettingRepoStub{
|
||||
values: map[string]string{
|
||||
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.GetAvailableMethodLimits(ctx)
|
||||
@@ -348,8 +381,12 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
|
||||
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
|
||||
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
|
||||
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
|
||||
if !ok {
|
||||
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
|
||||
}
|
||||
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
|
||||
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
|
||||
}
|
||||
|
||||
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
|
||||
@@ -359,7 +396,9 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
|
||||
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
|
||||
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
|
||||
}
|
||||
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
|
||||
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
|
||||
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
|
||||
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
|
||||
func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
|
||||
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||
ProviderKey: "alipay",
|
||||
Name: "Official Alipay",
|
||||
Config: map[string]string{"appId": "app-1"},
|
||||
Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
|
||||
SupportedTypes: []string{"alipay"},
|
||||
Enabled: true,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
|
||||
func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
|
||||
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||
ProviderKey: "wxpay",
|
||||
Name: "Official WeChat",
|
||||
Config: map[string]string{"appId": "wx-app"},
|
||||
Config: validWxpayProviderConfig(t),
|
||||
SupportedTypes: []string{"wxpay"},
|
||||
Enabled: false,
|
||||
})
|
||||
@@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
|
||||
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
|
||||
Enabled: boolPtrValue(true),
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
|
||||
@@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
|
||||
func boolPtrValue(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func validWxpayProviderConfig(t *testing.T) map[string]string {
|
||||
t.Helper()
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
privDER, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return map[string]string{
|
||||
"appId": "wx-app-test",
|
||||
"mchId": "mch-test",
|
||||
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
|
||||
"apiV3Key": "12345678901234567890123456789012",
|
||||
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
|
||||
"publicKeyId": "public-key-id-test",
|
||||
"certSerial": "cert-serial-test",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
|
||||
// if a payment was made, and processes it if so. This handles the case where
|
||||
// the provider's notify callback was missed (e.g. EasyPay popup mode).
|
||||
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
|
||||
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o, err := s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||
Only(ctx)
|
||||
@@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
|
||||
// triggering any upstream reconciliation. Signed resume-token recovery is the
|
||||
// only public recovery path allowed to query upstream state.
|
||||
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
|
||||
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o, err := s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||
Only(ctx)
|
||||
@@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
|
||||
outTradeNo := strings.TrimSpace(raw)
|
||||
if outTradeNo == "" {
|
||||
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
|
||||
}
|
||||
if len(outTradeNo) > 64 {
|
||||
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
|
||||
}
|
||||
for _, ch := range outTradeNo {
|
||||
switch {
|
||||
case ch >= 'a' && ch <= 'z':
|
||||
case ch >= 'A' && ch <= 'Z':
|
||||
case ch >= '0' && ch <= '9':
|
||||
case ch == '_' || ch == '-':
|
||||
default:
|
||||
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
|
||||
}
|
||||
}
|
||||
return outTradeNo, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
|
||||
now := time.Now()
|
||||
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||
|
||||
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
||||
SettingKeyWeChatConnectEnabled: "true",
|
||||
SettingKeyWeChatConnectAppID: "wx123456",
|
||||
@@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
|
||||
svc := &PaymentService{
|
||||
configService: &PaymentConfigService{
|
||||
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
|
||||
SettingKeyWeChatConnectEnabled: "true",
|
||||
SettingKeyWeChatConnectAppID: "wx123456",
|
||||
SettingKeyWeChatConnectAppSecret: "wechat-secret",
|
||||
SettingKeyWeChatConnectMode: "mp",
|
||||
SettingKeyWeChatConnectScopes: "snsapi_base",
|
||||
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
|
||||
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
|
||||
}},
|
||||
// Legacy stable signing key remains available for no-config upgrade compatibility.
|
||||
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
|
||||
Amount: 12.5,
|
||||
PaymentType: payment.TypeWxpay,
|
||||
IsWeChatBrowser: true,
|
||||
SrcURL: "https://merchant.example/payment?from=wechat",
|
||||
OrderType: payment.OrderTypeBalance,
|
||||
}, 12.5, 12.88, 0.03)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected oauth-required response, got nil")
|
||||
}
|
||||
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
|
||||
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
|
||||
}
|
||||
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
|
||||
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
|
||||
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
||||
SettingKeyWeChatConnectEnabled: "true",
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
|
||||
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
||||
|
||||
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get order by resume token: %w", err)
|
||||
}
|
||||
if claims.UserID > 0 && order.UserID != claims.UserID {
|
||||
return nil, fmt.Errorf("resume token user mismatch")
|
||||
return nil, invalidResumeTokenMatchError()
|
||||
}
|
||||
snapshot := psOrderProviderSnapshot(order)
|
||||
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
|
||||
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
||||
}
|
||||
}
|
||||
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
|
||||
return nil, fmt.Errorf("resume token provider instance mismatch")
|
||||
return nil, invalidResumeTokenMatchError()
|
||||
}
|
||||
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
|
||||
return nil, fmt.Errorf("resume token provider key mismatch")
|
||||
if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
|
||||
return nil, invalidResumeTokenMatchError()
|
||||
}
|
||||
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
|
||||
return nil, fmt.Errorf("resume token payment type mismatch")
|
||||
if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
|
||||
return nil, invalidResumeTokenMatchError()
|
||||
}
|
||||
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
|
||||
result := s.checkPaid(ctx, order)
|
||||
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func invalidResumeTokenMatchError() error {
|
||||
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
|
||||
}
|
||||
|
||||
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
|
||||
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
|
||||
|
||||
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "resume token")
|
||||
require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
|
||||
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
|
||||
require.Equal(t, order.ID, got.ID)
|
||||
require.Equal(t, 0, provider.queryCount)
|
||||
}
|
||||
|
||||
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||
svc := &PaymentService{
|
||||
entClient: newPaymentConfigServiceTestClient(t),
|
||||
}
|
||||
|
||||
_, err := svc.VerifyOrderPublic(context.Background(), " ")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
|
||||
|
||||
type PaymentResumeService struct {
|
||||
signingKey []byte
|
||||
verifyKeys [][]byte
|
||||
}
|
||||
|
||||
type visibleMethodLoadBalancer struct {
|
||||
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
|
||||
configService *PaymentConfigService
|
||||
}
|
||||
|
||||
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
|
||||
return &PaymentResumeService{signingKey: signingKey}
|
||||
func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
|
||||
svc := &PaymentResumeService{}
|
||||
if len(signingKey) > 0 {
|
||||
svc.signingKey = append([]byte(nil), signingKey...)
|
||||
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
|
||||
}
|
||||
for _, fallback := range verifyFallbacks {
|
||||
if len(fallback) == 0 {
|
||||
continue
|
||||
}
|
||||
cloned := append([]byte(nil), fallback...)
|
||||
duplicate := false
|
||||
for _, existing := range svc.verifyKeys {
|
||||
if bytes.Equal(existing, cloned) {
|
||||
duplicate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !duplicate {
|
||||
svc.verifyKeys = append(svc.verifyKeys, cloned)
|
||||
}
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) isSigningConfigured() bool {
|
||||
@@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
|
||||
}
|
||||
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
|
||||
if !s.verifySignature(parts[0], parts[1]) {
|
||||
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
@@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
||||
return json.Unmarshal(payload, dest)
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range s.verifyKeys {
|
||||
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
||||
if expiresAt <= 0 {
|
||||
return nil
|
||||
@@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) sign(payload string) string {
|
||||
mac := hmac.New(sha256.New, s.signingKey)
|
||||
return signPaymentResumePayload(payload, s.signingKey)
|
||||
}
|
||||
|
||||
func signPaymentResumePayload(payload string, key []byte) string {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
_, _ = mac.Write([]byte(payload))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
|
||||
|
||||
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
|
||||
OpenID: "openid-explicit-key",
|
||||
PaymentType: payment.TypeWxpay,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
|
||||
}
|
||||
|
||||
svc := &PaymentService{
|
||||
configService: &PaymentConfigService{
|
||||
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||
},
|
||||
}
|
||||
|
||||
claims, err := svc.ParseWeChatPaymentResumeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
|
||||
}
|
||||
if claims.OpenID != "openid-explicit-key" {
|
||||
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
|
||||
|
||||
legacyKey := []byte("0123456789abcdef0123456789abcdef")
|
||||
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
|
||||
OpenID: "openid-legacy-key",
|
||||
PaymentType: payment.TypeWxpay,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
|
||||
}
|
||||
|
||||
svc := &PaymentService{
|
||||
configService: &PaymentConfigService{
|
||||
encryptionKey: legacyKey,
|
||||
},
|
||||
}
|
||||
|
||||
claims, err := svc.ParseWeChatPaymentResumeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
|
||||
}
|
||||
if claims.OpenID != "openid-legacy-key" {
|
||||
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
|
||||
officialProviderKey = payment.TypeWxpay
|
||||
}
|
||||
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(officialProviderKey).
|
||||
SetName(tt.officialName).
|
||||
SetConfig("{}").
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -44,6 +48,8 @@ const (
|
||||
orderIDPrefix = "sub2_"
|
||||
)
|
||||
|
||||
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
|
||||
|
||||
// --- Types ---
|
||||
|
||||
// generateOutTradeNo creates a unique external order ID for payment providers.
|
||||
@@ -179,7 +185,7 @@ type PaymentService struct {
|
||||
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
|
||||
svc.resumeService = psNewPaymentResumeService(configService)
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
|
||||
if s.resumeService != nil {
|
||||
return s.resumeService
|
||||
}
|
||||
return NewPaymentResumeService(psResumeSigningKey(s.configService))
|
||||
return psNewPaymentResumeService(s.configService)
|
||||
}
|
||||
|
||||
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
|
||||
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
|
||||
return NewPaymentResumeService(signingKey, verifyFallbacks...)
|
||||
}
|
||||
|
||||
func psResumeSigningKey(configService *PaymentConfigService) []byte {
|
||||
signingKey, _ := psResumeSigningKeys(configService)
|
||||
return signingKey
|
||||
}
|
||||
|
||||
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
|
||||
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
|
||||
legacyKey := psResumeLegacyVerificationKey(configService)
|
||||
if len(signingKey) == 0 {
|
||||
if len(legacyKey) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return legacyKey, nil
|
||||
}
|
||||
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
|
||||
return signingKey, nil
|
||||
}
|
||||
return signingKey, [][]byte{legacyKey}
|
||||
}
|
||||
|
||||
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
|
||||
if configService == nil {
|
||||
return nil
|
||||
}
|
||||
return configService.encryptionKey
|
||||
}
|
||||
|
||||
func parsePaymentResumeSigningKey(raw string) []byte {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if len(raw) >= 64 && len(raw)%2 == 0 {
|
||||
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
|
||||
return decoded
|
||||
}
|
||||
}
|
||||
return []byte(raw)
|
||||
}
|
||||
|
||||
func psSliceContains(sl []string, s string) bool {
|
||||
for _, v := range sl {
|
||||
if v == s {
|
||||
|
||||
@@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
|
||||
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
|
||||
for _, inst := range instances {
|
||||
if !providerSupportsVisibleMethod(inst, method) {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, inst)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
|
||||
seen := make(map[string]struct{}, len(instances))
|
||||
keys := make([]string, 0, len(instances))
|
||||
for _, inst := range instances {
|
||||
if inst == nil {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(inst.ProviderKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
normalized := strings.ToLower(key)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerKey == "" {
|
||||
@@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
|
||||
supportedTypes string,
|
||||
enabled bool,
|
||||
) error {
|
||||
if s == nil || s.entClient == nil || !enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
|
||||
if len(claimedMethods) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true))
|
||||
if excludeID > 0 {
|
||||
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
|
||||
}
|
||||
instances, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query enabled payment providers: %w", err)
|
||||
}
|
||||
|
||||
for _, method := range claimedMethods {
|
||||
for _, inst := range instances {
|
||||
if providerSupportsVisibleMethod(inst, method) {
|
||||
return buildPaymentProviderConflictError(method, inst)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Visible methods are selected by configured source (official/easypay),
|
||||
// so multiple enabled providers can intentionally claim the same user-facing
|
||||
// method. Order creation and limits will route through the configured source.
|
||||
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context
|
||||
return providerKey, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
matching []*dbent.PaymentProviderInstance,
|
||||
) (string, error) {
|
||||
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
|
||||
case 0:
|
||||
return "", nil
|
||||
case 1:
|
||||
return strings.TrimSpace(providerKeys[0]), nil
|
||||
default:
|
||||
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
|
||||
if selected == nil {
|
||||
return "", infraerrors.BadRequest(
|
||||
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
|
||||
fmt.Sprintf("%s source has no enabled provider instance", method),
|
||||
)
|
||||
}
|
||||
return strings.TrimSpace(selected.ProviderKey), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
@@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
|
||||
}
|
||||
|
||||
matching := filterEnabledVisibleMethodInstances(instances, method)
|
||||
switch len(matching) {
|
||||
case 0:
|
||||
return nil, nil
|
||||
case 1:
|
||||
return matching[0], nil
|
||||
default:
|
||||
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
|
||||
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
|
||||
if selected == nil {
|
||||
return nil, infraerrors.BadRequest(
|
||||
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
|
||||
fmt.Sprintf("%s source has no enabled provider instance", method),
|
||||
)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user