fix(review): harden payment, oauth, and migration paths

This commit is contained in:
IanShaw027
2026-04-22 10:26:22 +08:00
parent 7fbd5177c2
commit c229f33e9e
33 changed files with 704 additions and 79 deletions

View File

@@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
if !isValidProviderAmount(paid) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
"expected": o.PayAmount,
"paid": paid,
"tradeNo": tradeNo,
})
return fmt.Errorf("invalid paid amount from provider: %v", paid)
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func isValidProviderAmount(amount float64) bool {
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
return validateProviderSnapshotMetadata(order, providerKey, metadata)
}

View File

@@ -5,6 +5,7 @@ package service
import (
"context"
"errors"
"math"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
assert.False(t, ok)
}
func TestIsValidProviderAmount(t *testing.T) {
t.Parallel()
assert.True(t, isValidProviderAmount(0.01))
assert.False(t, isValidProviderAmount(0))
assert.False(t, isValidProviderAmount(-1))
assert.False(t, isValidProviderAmount(math.NaN()))
assert.False(t, isValidProviderAmount(math.Inf(1)))
}
func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
t.Parallel()

View File

@@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
if err != nil {
return nil, err
}
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
@@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
@@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
const maxAttempts = 5
for attempt := 0; attempt < maxAttempts; attempt++ {
candidate := generateOutTradeNo()
exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
if err != nil {
return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
}
if !exists {
return candidate, nil
}
}
return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
@@ -366,7 +385,10 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
if resume.isSigningConfigured() {
if canonicalReturnURL != "" {
if err := resume.ensureSigningKey(); err != nil {
return nil, err
}
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
@@ -482,6 +504,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
if err != nil {
return nil, err
}
if err := s.paymentResume().ensureSigningKey(); err != nil {
return nil, err
}
authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
if err != nil {

View File

@@ -150,6 +150,16 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if !isValidProviderAmount(resp.Amount) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
"expected": o.PayAmount,
"paid": resp.Amount,
"tradeNo": resp.TradeNo,
"queryRef": queryRef,
})
slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
return ""
}
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().

View File

@@ -234,6 +234,97 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-zero-amount@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-zero-amount-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
SetOutTradeNo("sub2_checkpaid_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusPending, got.Status)
require.Empty(t, got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusPending, reloaded.Status)
require.Empty(t, reloaded.PaymentTradeNo)
require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
require.Empty(t, redeemRepo.useCalls)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)

View File

@@ -159,6 +159,45 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
t.Parallel()
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Intentionally missing payment resume signing key.
encryptionKey: nil,
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
@@ -189,7 +228,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
settingRepo: &paymentConfigSettingRepoStub{values: values},
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
}