Files
sub2api/backend/internal/service/payment_refund.go
erio c738cfec93 fix(payment): critical audit fixes for security, idempotency and correctness
Backend fixes:
- #1: doSub subscription idempotency via audit log check
- #2: markFailed only when status=RECHARGING (prevents overwriting COMPLETED)
- #3: ExpireTimedOutOrders checks upstream payment before expiring
- #4: Public verify endpoint for payment result page (no auth required)
- #5: EasyPay QueryOrder returns amount, confirmPayment handles zero amount
- #6: WxPay notifyUrl priority: request-first, config-fallback
- #7: EasyPay remove double URL decode in VerifyNotification
- #8: checkPaid/cancelUpstreamPayment use order's provider instance
- #9: Amount NaN/Inf/negative validation in order creation and refund
- #10: Refund amount comparison uses tolerance instead of float64 ==
- #11: Skip balance deduction on retry when previous rollback failed
- #12: checkPaid logs fulfillment errors instead of silently ignoring
- #13: WxPay certSerial added to required config fields

Frontend fixes:
- Payment result page no longer requires authentication
- Public verify API fallback for expired sessions
2026-04-14 09:19:33 +08:00

217 lines
8.5 KiB
Go

package service
import (
"context"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// --- Refund Flow ---
// RequestRefund records a user-initiated refund request for order oid on behalf
// of user uid.
//
// The order must pass validateRefundRequest and the owner's current balance
// must be at least the order amount. The order is then switched from
// OrderStatusCompleted to OrderStatusRefundRequested with a conditional update;
// if no row matches, a CONFLICT error is returned. On success an audit entry
// with action "REFUND_REQUESTED" is written.
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
	order, err := s.validateRefundRequest(ctx, oid, uid)
	if err != nil {
		return err
	}

	owner, err := s.userRepo.GetByID(ctx, order.UserID)
	if err != nil {
		return fmt.Errorf("get user: %w", err)
	}
	if order.Amount > owner.Balance {
		return infraerrors.BadRequest("BALANCE_NOT_ENOUGH", "refund amount exceeds balance")
	}

	trimmedReason := strings.TrimSpace(reason)
	requestedAt := time.Now()
	requestedBy := fmt.Sprintf("%d", uid)

	// The ownership/status/type predicates are repeated in the WHERE clause so
	// the transition only applies while the row still matches what was
	// validated above; zero affected rows means it changed in between.
	transition := s.entClient.PaymentOrder.Update().
		Where(
			paymentorder.IDEQ(oid),
			paymentorder.UserIDEQ(uid),
			paymentorder.StatusEQ(OrderStatusCompleted),
			paymentorder.OrderTypeEQ(payment.OrderTypeBalance),
		).
		SetStatus(OrderStatusRefundRequested).
		SetRefundRequestedAt(requestedAt).
		SetRefundRequestReason(trimmedReason).
		SetRefundRequestedBy(requestedBy).
		SetRefundAmount(order.Amount)

	affected, err := transition.Save(ctx)
	if err != nil {
		return fmt.Errorf("update: %w", err)
	}
	if affected == 0 {
		return infraerrors.Conflict("CONFLICT", "order status changed")
	}

	auditDetail := map[string]any{"amount": order.Amount, "reason": trimmedReason}
	s.writeAuditLog(ctx, oid, "REFUND_REQUESTED", fmt.Sprintf("user:%d", uid), auditDetail)
	return nil
}
// validateRefundRequest loads order oid and checks that user uid may request a
// refund for it.
//
// It returns NOT_FOUND when the order cannot be loaded (any lookup error is
// reported this way), FORBIDDEN when the order belongs to another user,
// INVALID_ORDER_TYPE when it is not a balance order, and INVALID_STATUS when
// it is not in OrderStatusCompleted. Otherwise the loaded order is returned.
func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int64) (*dbent.PaymentOrder, error) {
	order, err := s.entClient.PaymentOrder.Get(ctx, oid)
	if err != nil {
		return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
	}

	// Checks run in a fixed order: ownership first, then type, then status.
	switch {
	case order.UserID != uid:
		return nil, infraerrors.Forbidden("FORBIDDEN", "no permission")
	case order.OrderType != payment.OrderTypeBalance:
		return nil, infraerrors.BadRequest("INVALID_ORDER_TYPE", "only balance orders can request refund")
	case order.Status != OrderStatusCompleted:
		return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
	}
	return order, nil
}
// PrepareRefund validates a refund for order oid and builds the RefundPlan
// that ExecuteRefund consumes. It does not modify the order.
//
// Parameters:
//   - amt: requested refund amount; NaN/Inf are rejected, and any value <= 0
//     means "refund the full order amount".
//   - reason: free-text reason; when blank, the order's stored refund request
//     reason is used, and failing that a generated "refund order:<id>" text.
//   - force: copied into the plan and passed to prepDeduct.
//   - deduct: when true, prepDeduct decides what is taken back from the user.
//
// Returns exactly one of: a plan (validation passed), a *RefundResult (the
// deduction step declined to proceed), or an error.
func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float64, reason string, force, deduct bool) (*RefundPlan, *RefundResult, error) {
	order, err := s.entClient.PaymentOrder.Get(ctx, oid)
	if err != nil {
		return nil, nil, infraerrors.NotFound("NOT_FOUND", "order not found")
	}

	refundable := []string{OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed}
	if !psSliceContains(refundable, order.Status) {
		return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
	}

	if math.IsInf(amt, 0) || math.IsNaN(amt) {
		return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
	}
	if amt <= 0 {
		amt = order.Amount
	}
	if amt-order.Amount > amountToleranceCNY {
		return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
	}

	// A refund within tolerance of the order amount counts as a full refund,
	// so the gateway is asked for the actual pay_amount (includes fees).
	gatewayAmount := amt
	if math.Abs(amt-order.Amount) <= amountToleranceCNY {
		gatewayAmount = order.PayAmount
	}

	// Reason precedence: explicit argument, stored request reason, generated text.
	refundReason := strings.TrimSpace(reason)
	switch {
	case refundReason != "":
		// Caller supplied a reason; keep it.
	case order.RefundRequestReason != nil && *order.RefundRequestReason != "":
		refundReason = *order.RefundRequestReason
	default:
		refundReason = fmt.Sprintf("refund order:%d", order.ID)
	}

	plan := &RefundPlan{
		OrderID:       oid,
		Order:         order,
		RefundAmount:  amt,
		GatewayAmount: gatewayAmount,
		Reason:        refundReason,
		Force:         force,
		DeductBalance: deduct,
		DeductionType: payment.DeductionTypeNone,
	}
	if !deduct {
		return plan, nil, nil
	}
	if declined := s.prepDeduct(ctx, order, plan, force); declined != nil {
		return nil, declined, nil
	}
	return plan, nil, nil
}
// prepDeduct fills in the deduction part of plan p for order o.
//
// Subscription orders are tagged DeductionTypeSubscription with no balance
// amount. For other orders the user's balance is fetched and the deduction is
// capped at that balance. If the user cannot be fetched, the plan is left
// without a deduction when force is set; otherwise a RefundResult asking for
// force is returned. A nil result means the plan is ready.
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
	if o.OrderType == payment.OrderTypeSubscription {
		p.DeductionType = payment.DeductionTypeSubscription
		return nil
	}

	user, err := s.userRepo.GetByID(ctx, o.UserID)
	switch {
	case err == nil:
		p.DeductionType = payment.DeductionTypeBalance
		// Never plan to deduct more than the user currently holds.
		p.BalanceToDeduct = math.Min(p.RefundAmount, user.Balance)
		return nil
	case force:
		// Forced refund: proceed with the plan's existing (no-deduction) settings.
		return nil
	default:
		return &RefundResult{Success: false, Warning: "cannot fetch user balance, use force", RequireForce: true}
	}
}
// ExecuteRefund carries out the refund described by plan p.
//
// Steps, in order:
//  1. Move the order into OrderStatusRefunding with a conditional update; if
//     no row matches, return a CONFLICT error.
//  2. For balance deductions, deduct p.BalanceToDeduct from the user, unless
//     a "REFUND_ROLLBACK_FAILED" audit entry exists for the order, in which
//     case the deduction is skipped and p.BalanceToDeduct is reset to 0.
//  3. Call the gateway via gwRefund; failures are handed to handleGwFail.
//  4. Record success via markRefundOk.
//
// Note that p is mutated when the deduction is skipped.
func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
	locked, err := s.entClient.PaymentOrder.Update().
		Where(
			paymentorder.IDEQ(p.OrderID),
			paymentorder.StatusIn(OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed),
		).
		SetStatus(OrderStatusRefunding).
		Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("lock: %w", err)
	}
	if locked == 0 {
		return nil, infraerrors.Conflict("CONFLICT", "order status changed")
	}

	if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
		if s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
			// A previous attempt deducted the balance and could not give it
			// back, so deducting again on retry would charge the user twice.
			slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
			p.BalanceToDeduct = 0
		} else if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
			s.restoreStatus(ctx, p)
			return nil, fmt.Errorf("deduction: %w", err)
		}
	}

	if gwErr := s.gwRefund(ctx, p); gwErr != nil {
		return s.handleGwFail(ctx, p, gwErr)
	}
	return s.markRefundOk(ctx, p)
}
// gwRefund asks the payment gateway to refund the plan's GatewayAmount.
//
// Orders without a PaymentTradeNo have nothing to refund upstream: an audit
// entry "REFUND_NO_TRADE_NO" is written and nil is returned. Otherwise the
// provider is resolved through getRefundProvider and its Refund method is
// called with the amount formatted to two decimal places.
func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
	order := p.Order
	if order.PaymentTradeNo == "" {
		s.writeAuditLog(ctx, order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
		return nil
	}

	// Use the exact provider instance that created this order, not a random one
	// from the registry. Each instance has its own merchant credentials.
	provider, err := s.getRefundProvider(ctx, order)
	if err != nil {
		return fmt.Errorf("get refund provider: %w", err)
	}

	request := payment.RefundRequest{
		TradeNo: order.PaymentTradeNo,
		OrderID: order.OutTradeNo,
		Amount:  strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
		Reason:  p.Reason,
	}
	_, err = provider.Refund(ctx, request)
	return err
}
// getRefundProvider returns the payment provider to use when refunding order o.
// It is a thin alias for getOrderProvider, which (per the original author's
// note) builds the provider from the order's original instance config and
// handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
	provider, err := s.getOrderProvider(ctx, o)
	return provider, err
}
// handleGwFail reacts to a gateway refund failure gErr for plan p.
//
// If RollbackRefund succeeds, the order status is restored, an audit entry
// "REFUND_GATEWAY_FAILED" is written, and a non-successful RefundResult with a
// warning is returned alongside a nil error. If the rollback fails, the order
// is marked OrderStatusRefundFailed with the failure time and reason, an audit
// entry "REFUND_FAILED" is written, and a REFUND_FAILED internal error is
// returned.
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
	detail := psErrMsg(gErr)

	if s.RollbackRefund(ctx, p, gErr) {
		s.restoreStatus(ctx, p)
		s.writeAuditLog(ctx, p.OrderID, "REFUND_GATEWAY_FAILED", "admin", map[string]any{"detail": detail})
		return &RefundResult{Success: false, Warning: "gateway failed: " + detail + ", rolled back"}, nil
	}

	// Persisting the failed state is best-effort (the caller already gets an
	// error), but a failure here leaves the order in the refunding state, so it
	// must at least be visible in the logs rather than silently dropped.
	if _, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).
		SetStatus(OrderStatusRefundFailed).
		SetFailedAt(time.Now()).
		SetFailedReason(detail).
		Save(ctx); err != nil {
		slog.Error("[CRITICAL] failed to persist refund-failed status", "orderID", p.OrderID, "error", err)
	}
	s.writeAuditLog(ctx, p.OrderID, "REFUND_FAILED", "admin", map[string]any{"detail": detail})
	return nil, infraerrors.InternalServer("REFUND_FAILED", detail)
}
// markRefundOk persists the final state of a refund whose gateway step
// succeeded (or was skipped) and writes a "REFUND_SUCCESS" audit entry.
//
// The order becomes OrderStatusPartiallyRefunded when the refunded amount is
// below the order amount by more than amountToleranceCNY, and
// OrderStatusRefunded otherwise. The tolerance matches the full-refund test in
// PrepareRefund, so a refund that was sent to the gateway as a full refund is
// never recorded as partial because of float rounding.
//
// Returns the successful RefundResult, or an error if the order update fails.
func (s *PaymentService) markRefundOk(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
	finalStatus := OrderStatusRefunded
	if p.Order.Amount-p.RefundAmount > amountToleranceCNY {
		finalStatus = OrderStatusPartiallyRefunded
	}

	// NOTE: this runs after the gateway step, so an error here means the order
	// row could not be updated even though the refund itself went through.
	_, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).
		SetStatus(finalStatus).
		SetRefundAmount(p.RefundAmount).
		SetRefundReason(p.Reason).
		SetRefundAt(time.Now()).
		SetForceRefund(p.Force).
		Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("mark refund: %w", err)
	}

	s.writeAuditLog(ctx, p.OrderID, "REFUND_SUCCESS", "admin", map[string]any{
		"refundAmount":    p.RefundAmount,
		"reason":          p.Reason,
		"balanceDeducted": p.BalanceToDeduct,
		"force":           p.Force,
	})
	return &RefundResult{Success: true, BalanceDeducted: p.BalanceToDeduct, SubDaysDeducted: p.SubDaysToDeduct}, nil
}
// RollbackRefund undoes the balance deduction of plan p after the gateway
// refund failed with gErr.
//
// It reports true when there was nothing to undo (no balance deduction in the
// plan) or when the balance update succeeded. When the update fails it logs a
// critical error, writes a "REFUND_ROLLBACK_FAILED" audit entry carrying both
// error messages, and reports false.
func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr error) bool {
	deducted := p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0
	if !deducted {
		return true
	}

	// Presumably UpdateBalance credits the given amount back to the user;
	// confirm against the userRepo implementation.
	err := s.userRepo.UpdateBalance(ctx, p.Order.UserID, p.BalanceToDeduct)
	if err == nil {
		return true
	}

	slog.Error("[CRITICAL] rollback failed", "orderID", p.OrderID, "amount", p.BalanceToDeduct, "error", err)
	s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{
		"gatewayError":    psErrMsg(gErr),
		"rollbackError":   psErrMsg(err),
		"balanceDeducted": p.BalanceToDeduct,
	})
	return false
}
// restoreStatus puts the order back into a pre-refund status after an aborted
// refund attempt: OrderStatusRefundRequested if that is the status recorded on
// the plan's order snapshot, OrderStatusCompleted otherwise.
//
// The update is best-effort and the function returns nothing, but a failed
// update is logged because the order would otherwise remain in the refunding
// state with no trace of why.
//
// NOTE(review): an order whose snapshot status is OrderStatusRefundFailed is
// also restored to OrderStatusCompleted — confirm this is intended.
func (s *PaymentService) restoreStatus(ctx context.Context, p *RefundPlan) {
	previous := OrderStatusCompleted
	if p.Order.Status == OrderStatusRefundRequested {
		previous = OrderStatusRefundRequested
	}

	if _, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(previous).Save(ctx); err != nil {
		slog.Error("failed to restore order status after aborted refund", "orderID", p.OrderID, "status", previous, "error", err)
	}
}