package service import ( "context" "fmt" "log/slog" "strconv" "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // --- Cancel & Expire --- // Cancel rate limit configuration constants. const ( rateLimitUnitDay = "day" rateLimitUnitMinute = "minute" rateLimitUnitHour = "hour" rateLimitModeFixed = "fixed" checkPaidResultAlreadyPaid = "already_paid" checkPaidResultCancelled = "cancelled" ) func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error { if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 { return nil } windowStart := cancelRateLimitWindowStart(cfg) operator := fmt.Sprintf("user:%d", userID) count, err := s.entClient.PaymentAuditLog.Query(). Where( paymentauditlog.ActionEQ("ORDER_CANCELLED"), paymentauditlog.OperatorEQ(operator), paymentauditlog.CreatedAtGTE(windowStart), ).Count(ctx) if err != nil { slog.Error("check cancel rate limit failed", "userID", userID, "error", err) return nil // fail open } if count >= cfg.CancelRateLimitMax { return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited"). WithMetadata(map[string]string{ "max": strconv.Itoa(cfg.CancelRateLimitMax), "window": strconv.Itoa(cfg.CancelRateLimitWindow), "unit": cfg.CancelRateLimitUnit, }) } return nil } func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time { now := time.Now() w := cfg.CancelRateLimitWindow if w <= 0 { w = 1 } unit := cfg.CancelRateLimitUnit if unit == "" { unit = rateLimitUnitDay } if cfg.CancelRateLimitMode == rateLimitModeFixed { switch unit { case rateLimitUnitMinute: t := now.Truncate(time.Minute) return t.Add(-time.Duration(w-1) * time.Minute) case rateLimitUnitDay: y, m, d := now.Date() t := time.Date(y, m, d, 0, 0, 0, 0, now.Location()) return t.AddDate(0, 0, -(w - 1)) default: // hour t := now.Truncate(time.Hour) return t.Add(-time.Duration(w-1) * time.Hour) } } // rolling window switch unit { case rateLimitUnitMinute: return now.Add(-time.Duration(w) * time.Minute) case rateLimitUnitDay: return now.AddDate(0, 0, -w) default: // hour return now.Add(-time.Duration(w) * time.Hour) } } func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) { o, err := s.entClient.PaymentOrder.Get(ctx, orderID) if err != nil { return "", infraerrors.NotFound("NOT_FOUND", "order not found") } if o.UserID != userID { return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order") } if o.Status != OrderStatusPending { return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") } return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order") } func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) { o, err := s.entClient.PaymentOrder.Get(ctx, orderID) if err != nil { return "", infraerrors.NotFound("NOT_FOUND", "order not found") } if o.Status != OrderStatusPending { return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") } return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order") } func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) { if o.PaymentTradeNo != "" || o.PaymentType != "" { if s.checkPaid(ctx, o) == checkPaidResultAlreadyPaid { return checkPaidResultAlreadyPaid, nil } } c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx) if err != nil { return "", fmt.Errorf("update order status: %w", err) } if c > 0 { auditAction := "ORDER_CANCELLED" if fs == OrderStatusExpired { auditAction = "ORDER_EXPIRED" } s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad}) } return checkPaidResultCancelled, nil } func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { prov, err := s.getOrderProvider(ctx, o) if err != nil { return "" } queryRef := paymentOrderQueryReference(o, prov) if queryRef == "" { return "" } resp, err := prov.QueryOrder(ctx, queryRef) if err != nil { slog.Warn("query upstream failed", "orderID", o.ID, "error", err) return "" } if resp.Status == payment.ProviderStatusPaid { if !isValidProviderAmount(resp.Amount) { s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{ "expected": o.PayAmount, "paid": resp.Amount, "tradeNo": resp.TradeNo, "queryRef": queryRef, }) slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount) retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef) if !retryOK { return "" } resp = retriedResp } notificationTradeNo := o.PaymentTradeNo if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) { if _, updateErr := s.entClient.PaymentOrder.Update(). Where(paymentorder.IDEQ(o.ID)). SetPaymentTradeNo(upstreamTradeNo). Save(ctx); updateErr != nil { slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr) } else { o.PaymentTradeNo = upstreamTradeNo } notificationTradeNo = upstreamTradeNo } if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil { slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) // Still return already_paid — order was paid, fulfillment can be retried } return checkPaidResultAlreadyPaid } if cp, ok := prov.(payment.CancelableProvider); ok { _ = cp.CancelPayment(ctx, queryRef) } return "" } func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) { if prov == nil || strings.TrimSpace(queryRef) == "" { return nil, false } resp, err := prov.QueryOrder(ctx, queryRef) if err != nil { slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err) return nil, false } if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) { return nil, false } return resp, true } func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string { if order == nil { return "" } providerKey := "" if prov != nil { providerKey = strings.TrimSpace(prov.ProviderKey()) } if providerKey == "" { if snapshot := psOrderProviderSnapshot(order); snapshot != nil { providerKey = strings.TrimSpace(snapshot.ProviderKey) } } if providerKey == "" { providerKey = strings.TrimSpace(psStringValue(order.ProviderKey)) } if providerKey == "" { providerKey = strings.TrimSpace(order.PaymentType) } switch payment.GetBasePaymentType(providerKey) { case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay: return strings.TrimSpace(order.OutTradeNo) default: if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" { return tradeNo } return strings.TrimSpace(order.OutTradeNo) } } func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool { upstreamTradeNo = strings.TrimSpace(upstreamTradeNo) if upstreamTradeNo == "" { return false } if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) { return false } if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) { return false } return true } // VerifyOrderByOutTradeNo actively queries the upstream provider to check // if a payment was made, and processes it if so. This handles the case where // the provider's notify callback was missed (e.g. EasyPay popup mode). func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) if err != nil { return nil, err } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) if err != nil { return nil, infraerrors.NotFound("NOT_FOUND", "order not found") } if o.UserID != userID { return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order") } // Only verify orders that are still pending or recently expired if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { result := s.checkPaid(ctx, o) if result == checkPaidResultAlreadyPaid { // Reload order to get updated status o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) if err != nil { return nil, fmt.Errorf("reload order: %w", err) } } } return o, nil } // VerifyOrderPublic returns the currently persisted public order state without // triggering any upstream reconciliation. Signed resume-token recovery is the // only public recovery path allowed to query upstream state. func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) if err != nil { return nil, err } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) if err != nil { return nil, infraerrors.NotFound("NOT_FOUND", "order not found") } return o, nil } func normalizeOrderLookupOutTradeNo(raw string) (string, error) { outTradeNo := strings.TrimSpace(raw) if outTradeNo == "" { return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required") } if len(outTradeNo) > 64 { return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") } for _, ch := range outTradeNo { switch { case ch >= 'a' && ch <= 'z': case ch >= 'A' && ch <= 'Z': case ch >= '0' && ch <= '9': case ch == '_' || ch == '-': default: return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") } } return outTradeNo, nil } func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { now := time.Now() orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) if err != nil { return 0, fmt.Errorf("query expired: %w", err) } n := 0 for _, o := range orders { // Check upstream payment status before expiring — the user may have // paid just before timeout and the webhook hasn't arrived yet. outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired") if outcome == checkPaidResultAlreadyPaid { slog.Info("order was paid during expiry", "orderID", o.ID) continue } if outcome != "" { n++ } } return n, nil } // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { inst, err := s.getOrderProviderInstance(ctx, o) if err != nil { return nil, fmt.Errorf("load order provider instance: %w", err) } if inst != nil { return s.createProviderFromInstance(ctx, inst) } if !paymentOrderAllowsRegistryFallback(o) { return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID) } providerKey := paymentOrderFallbackProviderKey(s.registry, o) if providerKey == "" { return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID) } if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool { if order == nil { return false } if psOrderProviderSnapshot(order) != nil { return false } if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" { return false } if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" { return false } return true } func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string { if order == nil { return "" } if registry != nil { if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" { return key } } return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType))) } func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { if inst == nil { return nil, fmt.Errorf("payment provider instance is missing") } cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) if err != nil { return nil, fmt.Errorf("load provider instance config: %w", err) } if inst.PaymentMode != "" { cfg["paymentMode"] = inst.PaymentMode } instID := strconv.FormatInt(int64(inst.ID), 10) prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) if err != nil { return nil, fmt.Errorf("create provider from instance: %w", err) } return prov, nil } func psStringValue(value *string) string { if value == nil { return "" } return *value }