fix(payment): support source routing and compatible resume signing

This commit is contained in:
IanShaw027
2026-04-22 12:30:17 +08:00
parent b2e0712190
commit d6a04bb772
12 changed files with 570 additions and 136 deletions

View File

@@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
} }
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
require.NoError(t, err) require.NoError(t, err)
@@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.Contains(t, resp.Data, "expires_at") require.Contains(t, resp.Data, "expires_at")
require.Contains(t, resp.Data, "refund_amount") require.Contains(t, resp.Data, "refund_amount")
} }
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-mismatch-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
SetOutTradeNo("resolve-order-mismatch-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID + 999,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":" "}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}

View File

@@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err) return nil, fmt.Errorf("query provider instances: %w", err)
} }
typeInstances := pcGroupByPaymentType(instances) typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances) typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{ resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)), Methods: make(map[string]MethodLimits, len(typeInstances)),
} }
@@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil return resp, nil
} }
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 { if len(typeInstances) == 0 {
return typeInstances return typeInstances
} }
@@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method) matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 { providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil || providerKey == "" {
delete(filtered, method) delete(filtered, method)
continue continue
} }
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]} selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
if len(selectedInstances) == 0 {
delete(filtered, method)
continue
}
filtered[method] = selectedInstances
} }
return filtered return filtered
} }

View File

@@ -301,7 +301,35 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}) })
} }
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) { func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
tests := []struct {
name string
sourceSetting string
wantAlipaySingleMin float64
wantAlipaySingleMax float64
wantGlobalMin float64
wantGlobalMax float64
}{
{
name: "official source",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantAlipaySingleMin: 10,
wantAlipaySingleMax: 100,
wantGlobalMin: 10,
wantGlobalMax: 300,
},
{
name: "easypay source",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantAlipaySingleMin: 20,
wantAlipaySingleMax: 200,
wantGlobalMin: 20,
wantGlobalMax: 300,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client := newPaymentConfigServiceTestClient(t) client := newPaymentConfigServiceTestClient(t)
@@ -341,6 +369,11 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
svc := &PaymentConfigService{ svc := &PaymentConfigService{
entClient: client, entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
},
},
} }
resp, err := svc.GetAvailableMethodLimits(ctx) resp, err := svc.GetAvailableMethodLimits(ctx)
@@ -348,8 +381,12 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
t.Fatalf("GetAvailableMethodLimits returned error: %v", err) t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
} }
if _, ok := resp.Methods[payment.TypeAlipay]; ok { alipayLimits, ok := resp.Methods[payment.TypeAlipay]
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay]) if !ok {
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
}
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
} }
wxpayLimits, ok := resp.Methods[payment.TypeWxpay] wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
@@ -359,7 +396,9 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
} }
if resp.GlobalMin != 30 || resp.GlobalMax != 300 { if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax) t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
}
})
} }
} }

View File

@@ -4,9 +4,12 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) {
} }
} }
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) { func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
@@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay", ProviderKey: "alipay",
Name: "Official Alipay", Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"}, Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
SupportedTypes: []string{"alipay"}, SupportedTypes: []string{"alipay"},
Enabled: true, Enabled: true,
}) })
require.Error(t, err) require.NoError(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
} }
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) { func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
@@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay", ProviderKey: "wxpay",
Name: "Official WeChat", Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"}, Config: validWxpayProviderConfig(t),
SupportedTypes: []string{"wxpay"}, SupportedTypes: []string{"wxpay"},
Enabled: false, Enabled: false,
}) })
@@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true), Enabled: boolPtrValue(true),
}) })
require.Error(t, err) require.NoError(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
} }
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
@@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
func boolPtrValue(v bool) *bool { func boolPtrValue(v bool) *bool {
return &v return &v
} }
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return map[string]string{
"appId": "wx-app-test",
"mchId": "mch-test",
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": "12345678901234567890123456789012",
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-test",
"certSerial": "cert-serial-test",
}
}

View File

@@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where // if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode). // the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query(). o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)). Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx) Only(ctx)
@@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the // triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state. // only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query(). o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)). Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx) Only(ctx)
@@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return o, nil return o, nil
} }
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
outTradeNo := strings.TrimSpace(raw)
if outTradeNo == "" {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
}
if len(outTradeNo) > 64 {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
for _, ch := range outTradeNo {
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now() now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"strings"
"testing" "testing"
"time" "time"
@@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
} }
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
svc := newWeChatPaymentOAuthTestService(map[string]string{ svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456", SettingKeyWeChatConnectAppID: "wx123456",
@@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin
} }
} }
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if resp == nil {
t.Fatal("expected oauth-required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{ svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectEnabled: "true",

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
@@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil { if err != nil {
if dbent.IsNotFound(err) {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return nil, fmt.Errorf("get order by resume token: %w", err) return nil, fmt.Errorf("get order by resume token: %w", err)
} }
if claims.UserID > 0 && order.UserID != claims.UserID { if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch") return nil, invalidResumeTokenMatchError()
} }
snapshot := psOrderProviderSnapshot(order) snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
@@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
} }
} }
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch") return nil, invalidResumeTokenMatchError()
} }
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey { if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
return nil, fmt.Errorf("resume token provider key mismatch") return nil, invalidResumeTokenMatchError()
} }
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
return nil, fmt.Errorf("resume token payment type mismatch") return nil, invalidResumeTokenMatchError()
} }
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order) result := s.checkPaid(ctx, order)
@@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil return order, nil
} }
func invalidResumeTokenMatchError() error {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
} }

View File

@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_, err = svc.GetPublicOrderByResumeToken(ctx, token) _, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "resume token") require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
} }
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
@@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require.Equal(t, order.ID, got.ID) require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount) require.Equal(t, 0, provider.queryCount)
} }
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
svc := &PaymentService{
entClient: newPaymentConfigServiceTestClient(t),
}
_, err := svc.VerifyOrderPublic(context.Background(), " ")
require.Error(t, err)
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
}

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"bytes"
"context" "context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
@@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type PaymentResumeService struct { type PaymentResumeService struct {
signingKey []byte signingKey []byte
verifyKeys [][]byte
} }
type visibleMethodLoadBalancer struct { type visibleMethodLoadBalancer struct {
@@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService *PaymentConfigService configService *PaymentConfigService
} }
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey} svc := &PaymentResumeService{}
if len(signingKey) > 0 {
svc.signingKey = append([]byte(nil), signingKey...)
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
}
for _, fallback := range verifyFallbacks {
if len(fallback) == 0 {
continue
}
cloned := append([]byte(nil), fallback...)
duplicate := false
for _, existing := range svc.verifyKeys {
if bytes.Equal(existing, cloned) {
duplicate = true
break
}
}
if !duplicate {
svc.verifyKeys = append(svc.verifyKeys, cloned)
}
}
return svc
} }
func (s *PaymentResumeService) isSigningConfigured() bool { func (s *PaymentResumeService) isSigningConfigured() bool {
@@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if len(parts) != 2 || parts[0] == "" || parts[1] == "" { if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
} }
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { if !s.verifySignature(parts[0], parts[1]) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
} }
payload, err := base64.RawURLEncoding.DecodeString(parts[0]) payload, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest) return json.Unmarshal(payload, dest)
} }
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
if s == nil {
return false
}
for _, key := range s.verifyKeys {
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
return true
}
}
return false
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 { if expiresAt <= 0 {
return nil return nil
@@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
} }
func (s *PaymentResumeService) sign(payload string) string { func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey) return signPaymentResumePayload(payload, s.signingKey)
}
func signPaymentResumePayload(payload string, key []byte) string {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload)) _, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
} }

View File

@@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
} }
} }
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: legacyKey,
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) { func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel() t.Parallel()
@@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl
officialProviderKey = payment.TypeWxpay officialProviderKey = payment.TypeWxpay
} }
_, err = client.PaymentProviderInstance.Create(). _, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey). SetProviderKey(officialProviderKey).
SetName(tt.officialName). SetName(tt.officialName).
SetConfig("{}"). SetConfig("{}").

View File

@@ -1,10 +1,14 @@
package service package service
import ( import (
"bytes"
"context" "context"
"encoding/hex"
"fmt" "fmt"
"log/slog" "log/slog"
"math/rand/v2" "math/rand/v2"
"os"
"strings"
"sync" "sync"
"time" "time"
@@ -44,6 +48,8 @@ const (
orderIDPrefix = "sub2_" orderIDPrefix = "sub2_"
) )
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
// --- Types --- // --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers. // generateOutTradeNo creates a unique external order ID for payment providers.
@@ -179,7 +185,7 @@ type PaymentService struct {
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) svc.resumeService = psNewPaymentResumeService(configService)
return svc return svc
} }
@@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil { if s.resumeService != nil {
return s.resumeService return s.resumeService
} }
return NewPaymentResumeService(psResumeSigningKey(s.configService)) return psNewPaymentResumeService(s.configService)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
signingKey, verifyFallbacks := psResumeSigningKeys(configService)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
} }
func psResumeSigningKey(configService *PaymentConfigService) []byte { func psResumeSigningKey(configService *PaymentConfigService) []byte {
signingKey, _ := psResumeSigningKeys(configService)
return signingKey
}
func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
legacyKey := psResumeLegacyVerificationKey(configService)
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
}
return legacyKey, nil
}
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
return signingKey, nil
}
return signingKey, [][]byte{legacyKey}
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil { if configService == nil {
return nil return nil
} }
return configService.encryptionKey return configService.encryptionKey
} }
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if len(raw) >= 64 && len(raw)%2 == 0 {
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
return decoded
}
}
return []byte(raw)
}
func psSliceContains(sl []string, s string) bool { func psSliceContains(sl []string, s string) bool {
for _, v := range sl { for _, v := range sl {
if v == s { if v == s {

View File

@@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return filtered return filtered
} }
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if !providerSupportsVisibleMethod(inst, method) {
continue
}
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
continue
}
filtered = append(filtered, inst)
}
return filtered
}
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
seen := make(map[string]struct{}, len(instances))
keys := make([]string, 0, len(instances))
for _, inst := range instances {
if inst == nil {
continue
}
key := strings.TrimSpace(inst.ProviderKey)
if key == "" {
continue
}
normalized := strings.ToLower(key)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, key)
}
return keys
}
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
providerKey = strings.TrimSpace(providerKey) providerKey = strings.TrimSpace(providerKey)
if providerKey == "" { if providerKey == "" {
@@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes string, supportedTypes string,
enabled bool, enabled bool,
) error { ) error {
if s == nil || s.entClient == nil || !enabled { // Visible methods are selected by configured source (official/easypay),
return nil // so multiple enabled providers can intentionally claim the same user-facing
} // method. Order creation and limits will route through the configured source.
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
}
}
}
return nil return nil
} }
@@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context
return providerKey, nil return providerKey, nil
} }
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
ctx context.Context,
method string,
matching []*dbent.PaymentProviderInstance,
) (string, error) {
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
case 0:
return "", nil
case 1:
return strings.TrimSpace(providerKeys[0]), nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return "", err
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return strings.TrimSpace(selected.ProviderKey), nil
}
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
ctx context.Context, ctx context.Context,
method string, method string,
@@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
} }
matching := filterEnabledVisibleMethodInstances(instances, method) matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) { providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
if selected == nil {
return nil, infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return selected, nil
}
} }