fix(payment): support source routing and compatible resume signing
This commit is contained in:
@@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
|
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
|
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
|||||||
require.Contains(t, resp.Data, "expires_at")
|
require.Contains(t, resp.Data, "expires_at")
|
||||||
require.Contains(t, resp.Data, "refund_amount")
|
require.Contains(t, resp.Data, "refund_amount")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("public-resolve-mismatch@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetUsername("public-resolve-mismatch-user").
|
||||||
|
Save(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
order, err := client.PaymentOrder.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetUserEmail(user.Email).
|
||||||
|
SetUserName(user.Username).
|
||||||
|
SetAmount(100).
|
||||||
|
SetPayAmount(103).
|
||||||
|
SetFeeRate(0.03).
|
||||||
|
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
|
||||||
|
SetOutTradeNo("resolve-order-mismatch-no").
|
||||||
|
SetPaymentType(payment.TypeAlipay).
|
||||||
|
SetPaymentTradeNo("trade-public-resolve-mismatch").
|
||||||
|
SetOrderType(payment.OrderTypeBalance).
|
||||||
|
SetStatus(service.OrderStatusPaid).
|
||||||
|
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||||
|
SetPaidAt(time.Now()).
|
||||||
|
SetClientIP("127.0.0.1").
|
||||||
|
SetSrcHost("api.example.com").
|
||||||
|
Save(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
|
||||||
|
OrderID: order.ID,
|
||||||
|
UserID: user.ID + 999,
|
||||||
|
PaymentType: payment.TypeAlipay,
|
||||||
|
CanonicalReturnURL: "https://app.example.com/payment/result",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||||
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
"/api/v1/payment/public/orders/resolve",
|
||||||
|
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
|
||||||
|
)
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.ResolveOrderPublicByResumeToken(ctx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, http.StatusBadRequest, resp.Code)
|
||||||
|
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||||
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
"/api/v1/payment/public/orders/verify",
|
||||||
|
bytes.NewBufferString(`{"out_trade_no":" "}`),
|
||||||
|
)
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.VerifyOrderPublic(ctx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, http.StatusBadRequest, resp.Code)
|
||||||
|
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
|||||||
return nil, fmt.Errorf("query provider instances: %w", err)
|
return nil, fmt.Errorf("query provider instances: %w", err)
|
||||||
}
|
}
|
||||||
typeInstances := pcGroupByPaymentType(instances)
|
typeInstances := pcGroupByPaymentType(instances)
|
||||||
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
|
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
|
||||||
resp := &MethodLimitsResponse{
|
resp := &MethodLimitsResponse{
|
||||||
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
||||||
}
|
}
|
||||||
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
|
func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
|
||||||
if len(typeInstances) == 0 {
|
if len(typeInstances) == 0 {
|
||||||
return typeInstances
|
return typeInstances
|
||||||
}
|
}
|
||||||
@@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
|
|||||||
|
|
||||||
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
|
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
|
||||||
matching := filterEnabledVisibleMethodInstances(instances, method)
|
matching := filterEnabledVisibleMethodInstances(instances, method)
|
||||||
if len(matching) != 1 {
|
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
|
||||||
|
if err != nil || providerKey == "" {
|
||||||
delete(filtered, method)
|
delete(filtered, method)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
|
selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
|
||||||
|
if len(selectedInstances) == 0 {
|
||||||
|
delete(filtered, method)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered[method] = selectedInstances
|
||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -301,65 +301,104 @@ func TestPcInstanceTypeLimits(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
|
func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
|
||||||
ctx := context.Background()
|
tests := []struct {
|
||||||
client := newPaymentConfigServiceTestClient(t)
|
name string
|
||||||
|
sourceSetting string
|
||||||
_, err := client.PaymentProviderInstance.Create().
|
wantAlipaySingleMin float64
|
||||||
SetProviderKey(payment.TypeAlipay).
|
wantAlipaySingleMax float64
|
||||||
SetName("Official Alipay").
|
wantGlobalMin float64
|
||||||
SetConfig("{}").
|
wantGlobalMax float64
|
||||||
SetSupportedTypes("alipay").
|
}{
|
||||||
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
|
{
|
||||||
SetEnabled(true).
|
name: "official source",
|
||||||
Save(ctx)
|
sourceSetting: VisibleMethodSourceOfficialAlipay,
|
||||||
if err != nil {
|
wantAlipaySingleMin: 10,
|
||||||
t.Fatalf("create official alipay instance: %v", err)
|
wantAlipaySingleMax: 100,
|
||||||
}
|
wantGlobalMin: 10,
|
||||||
_, err = client.PaymentProviderInstance.Create().
|
wantGlobalMax: 300,
|
||||||
SetProviderKey(payment.TypeEasyPay).
|
},
|
||||||
SetName("EasyPay Alipay").
|
{
|
||||||
SetConfig("{}").
|
name: "easypay source",
|
||||||
SetSupportedTypes("alipay").
|
sourceSetting: VisibleMethodSourceEasyPayAlipay,
|
||||||
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
|
wantAlipaySingleMin: 20,
|
||||||
SetEnabled(true).
|
wantAlipaySingleMax: 200,
|
||||||
Save(ctx)
|
wantGlobalMin: 20,
|
||||||
if err != nil {
|
wantGlobalMax: 300,
|
||||||
t.Fatalf("create easypay alipay instance: %v", err)
|
},
|
||||||
}
|
|
||||||
_, err = client.PaymentProviderInstance.Create().
|
|
||||||
SetProviderKey(payment.TypeWxpay).
|
|
||||||
SetName("Official WeChat").
|
|
||||||
SetConfig("{}").
|
|
||||||
SetSupportedTypes("wxpay").
|
|
||||||
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
|
|
||||||
SetEnabled(true).
|
|
||||||
Save(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create official wxpay instance: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := &PaymentConfigService{
|
for _, tt := range tests {
|
||||||
entClient: client,
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
}
|
ctx := context.Background()
|
||||||
|
client := newPaymentConfigServiceTestClient(t)
|
||||||
|
|
||||||
resp, err := svc.GetAvailableMethodLimits(ctx)
|
_, err := client.PaymentProviderInstance.Create().
|
||||||
if err != nil {
|
SetProviderKey(payment.TypeAlipay).
|
||||||
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
|
SetName("Official Alipay").
|
||||||
}
|
SetConfig("{}").
|
||||||
|
SetSupportedTypes("alipay").
|
||||||
|
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create official alipay instance: %v", err)
|
||||||
|
}
|
||||||
|
_, err = client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeEasyPay).
|
||||||
|
SetName("EasyPay Alipay").
|
||||||
|
SetConfig("{}").
|
||||||
|
SetSupportedTypes("alipay").
|
||||||
|
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create easypay alipay instance: %v", err)
|
||||||
|
}
|
||||||
|
_, err = client.PaymentProviderInstance.Create().
|
||||||
|
SetProviderKey(payment.TypeWxpay).
|
||||||
|
SetName("Official WeChat").
|
||||||
|
SetConfig("{}").
|
||||||
|
SetSupportedTypes("wxpay").
|
||||||
|
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
|
||||||
|
SetEnabled(true).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create official wxpay instance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
|
svc := &PaymentConfigService{
|
||||||
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
|
entClient: client,
|
||||||
}
|
settingRepo: &paymentConfigSettingRepoStub{
|
||||||
|
values: map[string]string{
|
||||||
|
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
|
resp, err := svc.GetAvailableMethodLimits(ctx)
|
||||||
if !ok {
|
if err != nil {
|
||||||
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
|
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
|
||||||
}
|
}
|
||||||
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
|
|
||||||
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
|
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
|
||||||
}
|
if !ok {
|
||||||
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
|
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
|
||||||
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
|
}
|
||||||
|
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
|
||||||
|
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
|
||||||
|
}
|
||||||
|
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
|
||||||
|
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
|
||||||
|
}
|
||||||
|
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
|
||||||
|
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,12 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
|
func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
|
|||||||
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||||
ProviderKey: "alipay",
|
ProviderKey: "alipay",
|
||||||
Name: "Official Alipay",
|
Name: "Official Alipay",
|
||||||
Config: map[string]string{"appId": "app-1"},
|
Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
|
||||||
SupportedTypes: []string{"alipay"},
|
SupportedTypes: []string{"alipay"},
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
})
|
})
|
||||||
require.Error(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
|
func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
|
|||||||
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||||
ProviderKey: "wxpay",
|
ProviderKey: "wxpay",
|
||||||
Name: "Official WeChat",
|
Name: "Official WeChat",
|
||||||
Config: map[string]string{"appId": "wx-app"},
|
Config: validWxpayProviderConfig(t),
|
||||||
SupportedTypes: []string{"wxpay"},
|
SupportedTypes: []string{"wxpay"},
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
})
|
})
|
||||||
@@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
|
|||||||
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
|
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
|
||||||
Enabled: boolPtrValue(true),
|
Enabled: boolPtrValue(true),
|
||||||
})
|
})
|
||||||
require.Error(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
|
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
|
||||||
@@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
|
|||||||
func boolPtrValue(v bool) *bool {
|
func boolPtrValue(v bool) *bool {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validWxpayProviderConfig(t *testing.T) map[string]string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privDER, err := x509.MarshalPKCS8PrivateKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return map[string]string{
|
||||||
|
"appId": "wx-app-test",
|
||||||
|
"mchId": "mch-test",
|
||||||
|
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
|
||||||
|
"apiV3Key": "12345678901234567890123456789012",
|
||||||
|
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
|
||||||
|
"publicKeyId": "public-key-id-test",
|
||||||
|
"certSerial": "cert-serial-test",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
|
|||||||
// if a payment was made, and processes it if so. This handles the case where
|
// if a payment was made, and processes it if so. This handles the case where
|
||||||
// the provider's notify callback was missed (e.g. EasyPay popup mode).
|
// the provider's notify callback was missed (e.g. EasyPay popup mode).
|
||||||
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
|
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
|
||||||
|
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
o, err := s.entClient.PaymentOrder.Query().
|
o, err := s.entClient.PaymentOrder.Query().
|
||||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
|
|||||||
// triggering any upstream reconciliation. Signed resume-token recovery is the
|
// triggering any upstream reconciliation. Signed resume-token recovery is the
|
||||||
// only public recovery path allowed to query upstream state.
|
// only public recovery path allowed to query upstream state.
|
||||||
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
|
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
|
||||||
|
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
o, err := s.entClient.PaymentOrder.Query().
|
o, err := s.entClient.PaymentOrder.Query().
|
||||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
|
|||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
|
||||||
|
outTradeNo := strings.TrimSpace(raw)
|
||||||
|
if outTradeNo == "" {
|
||||||
|
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
|
||||||
|
}
|
||||||
|
if len(outTradeNo) > 64 {
|
||||||
|
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
|
||||||
|
}
|
||||||
|
for _, ch := range outTradeNo {
|
||||||
|
switch {
|
||||||
|
case ch >= 'a' && ch <= 'z':
|
||||||
|
case ch >= 'A' && ch <= 'Z':
|
||||||
|
case ch >= '0' && ch <= '9':
|
||||||
|
case ch == '_' || ch == '-':
|
||||||
|
default:
|
||||||
|
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outTradeNo, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
|
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
|
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
||||||
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||||
|
|
||||||
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
||||||
SettingKeyWeChatConnectEnabled: "true",
|
SettingKeyWeChatConnectEnabled: "true",
|
||||||
SettingKeyWeChatConnectAppID: "wx123456",
|
SettingKeyWeChatConnectAppID: "wx123456",
|
||||||
@@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
|
||||||
|
svc := &PaymentService{
|
||||||
|
configService: &PaymentConfigService{
|
||||||
|
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyWeChatConnectEnabled: "true",
|
||||||
|
SettingKeyWeChatConnectAppID: "wx123456",
|
||||||
|
SettingKeyWeChatConnectAppSecret: "wechat-secret",
|
||||||
|
SettingKeyWeChatConnectMode: "mp",
|
||||||
|
SettingKeyWeChatConnectScopes: "snsapi_base",
|
||||||
|
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
|
||||||
|
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
|
||||||
|
}},
|
||||||
|
// Legacy stable signing key remains available for no-config upgrade compatibility.
|
||||||
|
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
|
||||||
|
Amount: 12.5,
|
||||||
|
PaymentType: payment.TypeWxpay,
|
||||||
|
IsWeChatBrowser: true,
|
||||||
|
SrcURL: "https://merchant.example/payment?from=wechat",
|
||||||
|
OrderType: payment.OrderTypeBalance,
|
||||||
|
}, 12.5, 12.88, 0.03)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatal("expected oauth-required response, got nil")
|
||||||
|
}
|
||||||
|
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
|
||||||
|
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
|
||||||
|
}
|
||||||
|
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
|
||||||
|
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
|
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
|
||||||
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
svc := newWeChatPaymentOAuthTestService(map[string]string{
|
||||||
SettingKeyWeChatConnectEnabled: "true",
|
SettingKeyWeChatConnectEnabled: "true",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
|
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
|
||||||
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
|||||||
|
|
||||||
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
|
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("get order by resume token: %w", err)
|
return nil, fmt.Errorf("get order by resume token: %w", err)
|
||||||
}
|
}
|
||||||
if claims.UserID > 0 && order.UserID != claims.UserID {
|
if claims.UserID > 0 && order.UserID != claims.UserID {
|
||||||
return nil, fmt.Errorf("resume token user mismatch")
|
return nil, invalidResumeTokenMatchError()
|
||||||
}
|
}
|
||||||
snapshot := psOrderProviderSnapshot(order)
|
snapshot := psOrderProviderSnapshot(order)
|
||||||
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
|
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
|
||||||
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
|
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
|
||||||
return nil, fmt.Errorf("resume token provider instance mismatch")
|
return nil, invalidResumeTokenMatchError()
|
||||||
}
|
}
|
||||||
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
|
if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
|
||||||
return nil, fmt.Errorf("resume token provider key mismatch")
|
return nil, invalidResumeTokenMatchError()
|
||||||
}
|
}
|
||||||
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
|
if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
|
||||||
return nil, fmt.Errorf("resume token payment type mismatch")
|
return nil, invalidResumeTokenMatchError()
|
||||||
}
|
}
|
||||||
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
|
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
|
||||||
result := s.checkPaid(ctx, order)
|
result := s.checkPaid(ctx, order)
|
||||||
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
|
|||||||
return order, nil
|
return order, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func invalidResumeTokenMatchError() error {
|
||||||
|
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
|
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
|
||||||
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
|
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
|
|||||||
|
|
||||||
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
|
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "resume token")
|
require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
|
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
|
||||||
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
|
|||||||
require.Equal(t, order.ID, got.ID)
|
require.Equal(t, order.ID, got.ID)
|
||||||
require.Equal(t, 0, provider.queryCount)
|
require.Equal(t, 0, provider.queryCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||||
|
svc := &PaymentService{
|
||||||
|
entClient: newPaymentConfigServiceTestClient(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.VerifyOrderPublic(context.Background(), " ")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
|
|||||||
|
|
||||||
type PaymentResumeService struct {
|
type PaymentResumeService struct {
|
||||||
signingKey []byte
|
signingKey []byte
|
||||||
|
verifyKeys [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type visibleMethodLoadBalancer struct {
|
type visibleMethodLoadBalancer struct {
|
||||||
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
|
|||||||
configService *PaymentConfigService
|
configService *PaymentConfigService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
|
func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
|
||||||
return &PaymentResumeService{signingKey: signingKey}
|
svc := &PaymentResumeService{}
|
||||||
|
if len(signingKey) > 0 {
|
||||||
|
svc.signingKey = append([]byte(nil), signingKey...)
|
||||||
|
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
|
||||||
|
}
|
||||||
|
for _, fallback := range verifyFallbacks {
|
||||||
|
if len(fallback) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cloned := append([]byte(nil), fallback...)
|
||||||
|
duplicate := false
|
||||||
|
for _, existing := range svc.verifyKeys {
|
||||||
|
if bytes.Equal(existing, cloned) {
|
||||||
|
duplicate = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !duplicate {
|
||||||
|
svc.verifyKeys = append(svc.verifyKeys, cloned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentResumeService) isSigningConfigured() bool {
|
func (s *PaymentResumeService) isSigningConfigured() bool {
|
||||||
@@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
|||||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
|
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
|
||||||
}
|
}
|
||||||
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
|
if !s.verifySignature(parts[0], parts[1]) {
|
||||||
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
|
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
|
||||||
}
|
}
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
|
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
@@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
|||||||
return json.Unmarshal(payload, dest)
|
return json.Unmarshal(payload, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, key := range s.verifyKeys {
|
||||||
|
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
||||||
if expiresAt <= 0 {
|
if expiresAt <= 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentResumeService) sign(payload string) string {
|
func (s *PaymentResumeService) sign(payload string) string {
|
||||||
mac := hmac.New(sha256.New, s.signingKey)
|
return signPaymentResumePayload(payload, s.signingKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func signPaymentResumePayload(payload string, key []byte) string {
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
_, _ = mac.Write([]byte(payload))
|
_, _ = mac.Write([]byte(payload))
|
||||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
|
||||||
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
|
||||||
|
|
||||||
|
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
|
||||||
|
OpenID: "openid-explicit-key",
|
||||||
|
PaymentType: payment.TypeWxpay,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &PaymentService{
|
||||||
|
configService: &PaymentConfigService{
|
||||||
|
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := svc.ParseWeChatPaymentResumeToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
|
||||||
|
}
|
||||||
|
if claims.OpenID != "openid-explicit-key" {
|
||||||
|
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
|
||||||
|
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
|
||||||
|
|
||||||
|
legacyKey := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
|
||||||
|
OpenID: "openid-legacy-key",
|
||||||
|
PaymentType: payment.TypeWxpay,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &PaymentService{
|
||||||
|
configService: &PaymentConfigService{
|
||||||
|
encryptionKey: legacyKey,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := svc.ParseWeChatPaymentResumeToken(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
|
||||||
|
}
|
||||||
|
if claims.OpenID != "openid-legacy-key" {
|
||||||
|
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -424,14 +477,14 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method payment.PaymentType
|
method payment.PaymentType
|
||||||
officialName string
|
officialName string
|
||||||
officialTypes string
|
officialTypes string
|
||||||
easyPayName string
|
easyPayName string
|
||||||
easyPayTypes string
|
easyPayTypes string
|
||||||
sourceSetting string
|
sourceSetting string
|
||||||
wantProvider string
|
wantProvider string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "alipay uses official source",
|
name: "alipay uses official source",
|
||||||
@@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
|
|||||||
officialProviderKey = payment.TypeWxpay
|
officialProviderKey = payment.TypeWxpay
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.PaymentProviderInstance.Create().
|
_, err := client.PaymentProviderInstance.Create().
|
||||||
SetProviderKey(officialProviderKey).
|
SetProviderKey(officialProviderKey).
|
||||||
SetName(tt.officialName).
|
SetName(tt.officialName).
|
||||||
SetConfig("{}").
|
SetConfig("{}").
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -44,6 +48,8 @@ const (
|
|||||||
orderIDPrefix = "sub2_"
|
orderIDPrefix = "sub2_"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
|
||||||
|
|
||||||
// --- Types ---
|
// --- Types ---
|
||||||
|
|
||||||
// generateOutTradeNo creates a unique external order ID for payment providers.
|
// generateOutTradeNo creates a unique external order ID for payment providers.
|
||||||
@@ -179,7 +185,7 @@ type PaymentService struct {
|
|||||||
|
|
||||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
||||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||||
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
|
svc.resumeService = psNewPaymentResumeService(configService)
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
|
|||||||
if s.resumeService != nil {
|
if s.resumeService != nil {
|
||||||
return s.resumeService
|
return s.resumeService
|
||||||
}
|
}
|
||||||
return NewPaymentResumeService(psResumeSigningKey(s.configService))
|
return psNewPaymentResumeService(s.configService)
|
||||||
|
}
|
||||||
|
|
||||||
|
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
|
||||||
|
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
|
||||||
|
return NewPaymentResumeService(signingKey, verifyFallbacks...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func psResumeSigningKey(configService *PaymentConfigService) []byte {
|
func psResumeSigningKey(configService *PaymentConfigService) []byte {
|
||||||
|
signingKey, _ := psResumeSigningKeys(configService)
|
||||||
|
return signingKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
|
||||||
|
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
|
||||||
|
legacyKey := psResumeLegacyVerificationKey(configService)
|
||||||
|
if len(signingKey) == 0 {
|
||||||
|
if len(legacyKey) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return legacyKey, nil
|
||||||
|
}
|
||||||
|
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
|
||||||
|
return signingKey, nil
|
||||||
|
}
|
||||||
|
return signingKey, [][]byte{legacyKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
|
||||||
if configService == nil {
|
if configService == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return configService.encryptionKey
|
return configService.encryptionKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parsePaymentResumeSigningKey(raw string) []byte {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(raw) >= 64 && len(raw)%2 == 0 {
|
||||||
|
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return []byte(raw)
|
||||||
|
}
|
||||||
|
|
||||||
func psSliceContains(sl []string, s string) bool {
|
func psSliceContains(sl []string, s string) bool {
|
||||||
for _, v := range sl {
|
for _, v := range sl {
|
||||||
if v == s {
|
if v == s {
|
||||||
|
|||||||
@@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
|
|||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
|
||||||
|
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
|
||||||
|
for _, inst := range instances {
|
||||||
|
if !providerSupportsVisibleMethod(inst, method) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, inst)
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
|
||||||
|
seen := make(map[string]struct{}, len(instances))
|
||||||
|
keys := make([]string, 0, len(instances))
|
||||||
|
for _, inst := range instances {
|
||||||
|
if inst == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(inst.ProviderKey)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalized := strings.ToLower(key)
|
||||||
|
if _, ok := seen[normalized]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[normalized] = struct{}{}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
|
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
|
||||||
providerKey = strings.TrimSpace(providerKey)
|
providerKey = strings.TrimSpace(providerKey)
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
@@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
|
|||||||
supportedTypes string,
|
supportedTypes string,
|
||||||
enabled bool,
|
enabled bool,
|
||||||
) error {
|
) error {
|
||||||
if s == nil || s.entClient == nil || !enabled {
|
// Visible methods are selected by configured source (official/easypay),
|
||||||
return nil
|
// so multiple enabled providers can intentionally claim the same user-facing
|
||||||
}
|
// method. Order creation and limits will route through the configured source.
|
||||||
|
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
|
||||||
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
|
|
||||||
if len(claimedMethods) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
query := s.entClient.PaymentProviderInstance.Query().
|
|
||||||
Where(paymentproviderinstance.EnabledEQ(true))
|
|
||||||
if excludeID > 0 {
|
|
||||||
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
|
|
||||||
}
|
|
||||||
instances, err := query.All(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("query enabled payment providers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, method := range claimedMethods {
|
|
||||||
for _, inst := range instances {
|
|
||||||
if providerSupportsVisibleMethod(inst, method) {
|
|
||||||
return buildPaymentProviderConflictError(method, inst)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context
|
|||||||
return providerKey, nil
|
return providerKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
|
||||||
|
ctx context.Context,
|
||||||
|
method string,
|
||||||
|
matching []*dbent.PaymentProviderInstance,
|
||||||
|
) (string, error) {
|
||||||
|
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
|
||||||
|
case 0:
|
||||||
|
return "", nil
|
||||||
|
case 1:
|
||||||
|
return strings.TrimSpace(providerKeys[0]), nil
|
||||||
|
default:
|
||||||
|
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
|
||||||
|
if selected == nil {
|
||||||
|
return "", infraerrors.BadRequest(
|
||||||
|
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
|
||||||
|
fmt.Sprintf("%s source has no enabled provider instance", method),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(selected.ProviderKey), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
|
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
method string,
|
method string,
|
||||||
@@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
|
|||||||
}
|
}
|
||||||
|
|
||||||
matching := filterEnabledVisibleMethodInstances(instances, method)
|
matching := filterEnabledVisibleMethodInstances(instances, method)
|
||||||
switch len(matching) {
|
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
|
||||||
case 0:
|
if err != nil {
|
||||||
return nil, nil
|
return nil, err
|
||||||
case 1:
|
|
||||||
return matching[0], nil
|
|
||||||
default:
|
|
||||||
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
|
|
||||||
if selected == nil {
|
|
||||||
return nil, infraerrors.BadRequest(
|
|
||||||
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
|
|
||||||
fmt.Sprintf("%s source has no enabled provider instance", method),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return selected, nil
|
|
||||||
}
|
}
|
||||||
|
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user