diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index 3fa59283..c54aba6a 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -158,6 +158,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer Code int `json:"code"` Msg string `json:"msg"` Status int `json:"status"` + Money string `json:"money"` } if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("easypay parse query: %w", err) @@ -166,7 +167,8 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer if resp.Status == easypayStatusPaid { status = payment.ProviderStatusPaid } - return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status}, nil + amount, _ := strconv.ParseFloat(resp.Money, 64) + return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil } func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) { @@ -174,9 +176,10 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st if err != nil { return nil, fmt.Errorf("parse notify: %w", err) } + // url.ParseQuery already decodes values — no additional decode needed. params := make(map[string]string) for k := range values { - params[k] = decodeURLValue(values.Get(k)) + params[k] = values.Get(k) } sign := params["sign"] if sign == "" { diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index 4b774d63..b8b99537 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -156,6 +156,7 @@ func TestNewWxpay(t *testing.T) { "apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes "publicKey": "fake-public-key", "publicKeyId": "key-id-001", + "certSerial": "SERIAL001", } // helper to clone and override config fields diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 828b68f3..641c6cd5 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -40,6 +40,14 @@ func RegisterPaymentRoutes( } } + // --- Public payment endpoints (no auth) --- + // Payment result page needs to verify order status without login + // (user session may have expired during provider redirect). + public := v1.Group("/payment/public") + { + public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) + } + // --- Webhook endpoints (no auth) --- webhook := v1.Group("/payment/webhook") { diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 7dd6d835..51307849 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -8,6 +8,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -32,9 +33,17 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo slog.Error("order not found", "orderID", oid) return nil } - if math.Abs(paid-o.PayAmount) > amountToleranceCNY { - s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) - return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) + // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). + // Also skip if paid is NaN/Inf (malformed provider data). + if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { + if math.Abs(paid-o.PayAmount) > amountToleranceCNY { + s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) + return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) + } + } + // Use order's expected amount when provider didn't report one + if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) { + paid = o.PayAmount } return s.toPaid(ctx, o, tradeNo, paid, pk) } @@ -241,27 +250,42 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error if err != nil || g.Status != payment.EntityStatusActive { return fmt.Errorf("group %d no longer exists or inactive", gid) } - _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: fmt.Sprintf("payment order %d", o.ID)}) + // Idempotency: check audit log to see if subscription was already assigned. + // Prevents double-extension on retry after markCompleted fails. + if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") { + slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid) + return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") + } + orderNote := fmt.Sprintf("payment order %d", o.ID) + _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote}) if err != nil { return fmt.Errorf("assign subscription: %w", err) } - now := time.Now() - _, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx) - if err != nil { - return fmt.Errorf("mark completed: %w", err) - } - s.writeAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS", "system", map[string]any{"groupId": gid, "days": days, "amount": o.Amount}) - return nil + return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") +} + +func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool { + oid := strconv.FormatInt(orderID, 10) + c, _ := s.entClient.PaymentAuditLog.Query(). + Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)). + Limit(1).Count(ctx) + return c > 0 } func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) { now := time.Now() r := psErrMsg(cause) - _, e := s.entClient.PaymentOrder.UpdateOneID(oid).SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx) + // Only mark FAILED if still in RECHARGING state — prevents overwriting + // a COMPLETED order when markCompleted failed but fulfillment succeeded. + c, e := s.entClient.PaymentOrder.Update(). + Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)). + SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx) if e != nil { slog.Error("mark FAILED", "orderID", oid, "error", e) } - s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r}) + if c > 0 { + s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r}) + } } func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error { diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index d61a0d88..e81af3f5 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -72,6 +72,9 @@ func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrder if req.OrderType == payment.OrderTypeSubscription { return s.validateSubOrder(ctx, req) } + if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 { + return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number") + } if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) { return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range"). WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)}) @@ -394,7 +397,7 @@ func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (s } func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) { - if o.PaymentTradeNo != "" && o.PaymentType != "" { + if o.PaymentTradeNo != "" || o.PaymentType != "" { if s.checkPaid(ctx, o) == "already_paid" { return "already_paid", nil } @@ -404,14 +407,17 @@ func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, return "", fmt.Errorf("update order status: %w", err) } if c > 0 { - s.writeAuditLog(ctx, o.ID, "ORDER_CANCELLED", op, map[string]any{"detail": ad}) + auditAction := "ORDER_CANCELLED" + if fs == OrderStatusExpired { + auditAction = "ORDER_EXPIRED" + } + s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad}) } return "cancelled", nil } func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { - s.EnsureProviders(ctx) - prov, err := s.registry.GetProvider(o.PaymentType) + prov, err := s.getOrderProvider(ctx, o) if err != nil { return "" } @@ -427,11 +433,14 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s return "" } if resp.Status == payment.ProviderStatusPaid { - _ = s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()) + if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { + slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) + // Still return already_paid — order was paid, fulfillment can be retried + } return "already_paid" } if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, o.PaymentTradeNo) + _ = cp.CancelPayment(ctx, tradeNo) } return "" } @@ -463,6 +472,27 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo return o, nil } +// VerifyOrderPublic verifies payment status without user authentication. +// Used by the payment result page when the user's session has expired. +func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { + o, err := s.entClient.PaymentOrder.Query(). + Where(paymentorder.OutTradeNo(outTradeNo)). + Only(ctx) + if err != nil { + return nil, infraerrors.NotFound("NOT_FOUND", "order not found") + } + if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { + result := s.checkPaid(ctx, o) + if result == "already_paid" { + o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) + if err != nil { + return nil, fmt.Errorf("reload order: %w", err) + } + } + } + return o, nil +} + func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { now := time.Now() orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) @@ -471,34 +501,39 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) } n := 0 for _, o := range orders { - // Cancel upstream payment (e.g. Stripe PaymentIntent) before marking expired - s.cancelUpstreamPayment(ctx, o) - c, e := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(OrderStatusExpired).Save(ctx) - if e != nil { - slog.Warn("expire failed", "orderID", o.ID, "error", e) + // Check upstream payment status before expiring — the user may have + // paid just before timeout and the webhook hasn't arrived yet. + outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired") + if outcome == "already_paid" { + slog.Info("order was paid during expiry", "orderID", o.ID) continue } - if c > 0 { - s.writeAuditLog(ctx, o.ID, "ORDER_EXPIRED", "system", map[string]any{"expiresAt": o.ExpiresAt.Format(time.RFC3339)}) + if outcome != "" { n++ } } return n, nil } -// cancelUpstreamPayment attempts to cancel the upstream provider payment (e.g. Stripe PaymentIntent). -func (s *PaymentService) cancelUpstreamPayment(ctx context.Context, o *dbent.PaymentOrder) { - if o.PaymentTradeNo == "" || o.PaymentType == "" { - return - } - s.EnsureProviders(ctx) - prov, err := s.registry.GetProvider(o.PaymentType) - if err != nil { - return - } - if cp, ok := prov.(payment.CancelableProvider); ok { - if err := cp.CancelPayment(ctx, o.PaymentTradeNo); err != nil { - slog.Warn("cancel upstream payment failed", "orderID", o.ID, "tradeNo", o.PaymentTradeNo, "error", err) +// getOrderProvider creates a provider using the order's original instance config. +// Falls back to registry lookup if instance ID is missing (legacy orders). +func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { + instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + if err == nil { + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) + if err == nil { + providerKey := s.registry.GetProviderKey(o.PaymentType) + if providerKey == "" { + providerKey = o.PaymentType + } + p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) + if err == nil { + return p, nil + } + } } } + s.EnsureProviders(ctx) + return s.registry.GetProvider(o.PaymentType) } diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index f3d20509..68f9c697 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -69,14 +69,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float if !psSliceContains(ok, o.Status) { return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } + if math.IsNaN(amt) || math.IsInf(amt, 0) { + return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") + } if amt <= 0 { amt = o.Amount } - if amt > o.Amount { + if amt-o.Amount > amountToleranceCNY { return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge") } + // Full refund: use actual pay_amount for gateway (includes fees) ga := amt - if amt == o.Amount { + if math.Abs(amt-o.Amount) <= amountToleranceCNY { ga = o.PayAmount } rr := strings.TrimSpace(reason) @@ -121,9 +125,16 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref return nil, infraerrors.Conflict("CONFLICT", "order status changed") } if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 { - if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil { - s.restoreStatus(ctx, p) - return nil, fmt.Errorf("deduction: %w", err) + // Skip balance deduction on retry if previous attempt already deducted + // but failed to roll back (REFUND_ROLLBACK_FAILED in audit log). + if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { + if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("deduction: %w", err) + } + } else { + slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID) + p.BalanceToDeduct = 0 } } if err := s.gwRefund(ctx, p); err != nil { @@ -137,15 +148,28 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error { s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"}) return nil } - s.EnsureProviders(ctx) - prov, err := s.registry.GetProvider(p.Order.PaymentType) + + // Use the exact provider instance that created this order, not a random one + // from the registry. Each instance has its own merchant credentials. + prov, err := s.getRefundProvider(ctx, p.Order) if err != nil { - return fmt.Errorf("get provider: %w", err) + return fmt.Errorf("get refund provider: %w", err) } - _, err = prov.Refund(ctx, payment.RefundRequest{TradeNo: p.Order.PaymentTradeNo, OrderID: p.Order.OutTradeNo, Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64), Reason: p.Reason}) + _, err = prov.Refund(ctx, payment.RefundRequest{ + TradeNo: p.Order.PaymentTradeNo, + OrderID: p.Order.OutTradeNo, + Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64), + Reason: p.Reason, + }) return err } +// getRefundProvider creates a provider using the order's original instance config. +// Delegates to getOrderProvider which handles instance lookup and fallback. +func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + return s.getOrderProvider(ctx, o) +} + func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) { if s.RollbackRefund(ctx, p, gErr) { s.restoreStatus(ctx, p) diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index cf0bf373..3c7df572 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -136,14 +136,17 @@ onMounted(async () => { } } - // If we have an out_trade_no from a provider return URL, actively verify - // the payment with the upstream provider (handles missed notify callbacks) + // Verify payment via public endpoint (works without login) if (outTradeNo) { try { - const result = await paymentAPI.verifyOrder(outTradeNo) + const result = await paymentAPI.verifyOrderPublic(outTradeNo) order.value = result.data } catch (_err: unknown) { - // Verification failed, fall through to normal order lookup + // Public verify failed, try authenticated endpoint if logged in + try { + const result = await paymentAPI.verifyOrder(outTradeNo) + order.value = result.data + } catch (_e: unknown) { /* fall through */ } } }