feat(payment): add complete payment system with multi-provider support
Add a full payment and subscription system supporting EasyPay (Alipay/WeChat), Stripe, and direct Alipay/WeChat Pay providers with multi-instance load balancing.
This commit is contained in:
323
backend/internal/handler/admin/payment_handler.go
Normal file
323
backend/internal/handler/admin/payment_handler.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PaymentHandler handles admin payment management.
|
||||
type PaymentHandler struct {
|
||||
paymentService *service.PaymentService
|
||||
configService *service.PaymentConfigService
|
||||
}
|
||||
|
||||
// NewPaymentHandler creates a new admin PaymentHandler.
|
||||
func NewPaymentHandler(paymentService *service.PaymentService, configService *service.PaymentConfigService) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
paymentService: paymentService,
|
||||
configService: configService,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dashboard ---
|
||||
|
||||
// GetDashboard returns payment dashboard statistics.
|
||||
// GET /api/v1/admin/payment/dashboard
|
||||
func (h *PaymentHandler) GetDashboard(c *gin.Context) {
|
||||
days := 30
|
||||
if d := c.Query("days"); d != "" {
|
||||
if v, err := strconv.Atoi(d); err == nil && v > 0 {
|
||||
days = v
|
||||
}
|
||||
}
|
||||
stats, err := h.paymentService.GetDashboardStats(c.Request.Context(), days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// --- Orders ---
|
||||
|
||||
// ListOrders returns a paginated list of all payment orders.
|
||||
// GET /api/v1/admin/payment/orders
|
||||
func (h *PaymentHandler) ListOrders(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
var userID int64
|
||||
if uid := c.Query("user_id"); uid != "" {
|
||||
if v, err := strconv.ParseInt(uid, 10, 64); err == nil {
|
||||
userID = v
|
||||
}
|
||||
}
|
||||
orders, total, err := h.paymentService.AdminListOrders(c.Request.Context(), userID, service.OrderListParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Status: c.Query("status"),
|
||||
OrderType: c.Query("order_type"),
|
||||
PaymentType: c.Query("payment_type"),
|
||||
Keyword: c.Query("keyword"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, orders, int64(total), page, pageSize)
|
||||
}
|
||||
|
||||
// GetOrderDetail returns detailed information about a single order.
|
||||
// GET /api/v1/admin/payment/orders/:id
|
||||
func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
|
||||
orderID, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
order, err := h.paymentService.GetOrderByID(c.Request.Context(), orderID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
|
||||
response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
|
||||
}
|
||||
|
||||
// CancelOrder cancels a pending order (admin).
|
||||
// POST /api/v1/admin/payment/orders/:id/cancel
|
||||
func (h *PaymentHandler) CancelOrder(c *gin.Context) {
|
||||
orderID, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg, err := h.paymentService.AdminCancelOrder(c.Request.Context(), orderID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": msg})
|
||||
}
|
||||
|
||||
// RetryFulfillment retries fulfillment for a paid order.
|
||||
// POST /api/v1/admin/payment/orders/:id/retry
|
||||
func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
|
||||
orderID, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.paymentService.RetryFulfillment(c.Request.Context(), orderID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "fulfillment retried"})
|
||||
}
|
||||
|
||||
// AdminProcessRefundRequest is the request body for admin refund processing.
|
||||
type AdminProcessRefundRequest struct {
|
||||
Amount float64 `json:"amount"`
|
||||
Reason string `json:"reason"`
|
||||
Force bool `json:"force"`
|
||||
DeductBalance bool `json:"deduct_balance"`
|
||||
}
|
||||
|
||||
// ProcessRefund processes a refund for an order (admin).
|
||||
// POST /api/v1/admin/payment/orders/:id/refund
|
||||
func (h *PaymentHandler) ProcessRefund(c *gin.Context) {
|
||||
orderID, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req AdminProcessRefundRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan, earlyResult, err := h.paymentService.PrepareRefund(c.Request.Context(), orderID, req.Amount, req.Reason, req.Force, req.DeductBalance)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if earlyResult != nil {
|
||||
response.Success(c, earlyResult)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.paymentService.ExecuteRefund(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// --- Subscription Plans ---
|
||||
|
||||
// ListPlans returns all subscription plans.
|
||||
// GET /api/v1/admin/payment/plans
|
||||
func (h *PaymentHandler) ListPlans(c *gin.Context) {
|
||||
plans, err := h.configService.ListPlans(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, plans)
|
||||
}
|
||||
|
||||
// CreatePlan creates a new subscription plan.
|
||||
// POST /api/v1/admin/payment/plans
|
||||
func (h *PaymentHandler) CreatePlan(c *gin.Context) {
|
||||
var req service.CreatePlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
plan, err := h.configService.CreatePlan(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Created(c, plan)
|
||||
}
|
||||
|
||||
// UpdatePlan updates an existing subscription plan.
|
||||
// PUT /api/v1/admin/payment/plans/:id
|
||||
func (h *PaymentHandler) UpdatePlan(c *gin.Context) {
|
||||
id, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req service.UpdatePlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
plan, err := h.configService.UpdatePlan(c.Request.Context(), id, req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, plan)
|
||||
}
|
||||
|
||||
// DeletePlan deletes a subscription plan.
|
||||
// DELETE /api/v1/admin/payment/plans/:id
|
||||
func (h *PaymentHandler) DeletePlan(c *gin.Context) {
|
||||
id, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.configService.DeletePlan(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// --- Provider Instances ---
|
||||
|
||||
// ListProviders returns all payment provider instances.
|
||||
// GET /api/v1/admin/payment/providers
|
||||
func (h *PaymentHandler) ListProviders(c *gin.Context) {
|
||||
providers, err := h.configService.ListProviderInstancesWithConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, providers)
|
||||
}
|
||||
|
||||
// CreateProvider creates a new payment provider instance.
|
||||
// POST /api/v1/admin/payment/providers
|
||||
func (h *PaymentHandler) CreateProvider(c *gin.Context) {
|
||||
var req service.CreateProviderInstanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
inst, err := h.configService.CreateProviderInstance(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.paymentService.RefreshProviders(c.Request.Context())
|
||||
response.Created(c, inst)
|
||||
}
|
||||
|
||||
// UpdateProvider updates an existing payment provider instance.
|
||||
// PUT /api/v1/admin/payment/providers/:id
|
||||
func (h *PaymentHandler) UpdateProvider(c *gin.Context) {
|
||||
id, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req service.UpdateProviderInstanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
inst, err := h.configService.UpdateProviderInstance(c.Request.Context(), id, req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.paymentService.RefreshProviders(c.Request.Context())
|
||||
response.Success(c, inst)
|
||||
}
|
||||
|
||||
// DeleteProvider deletes a payment provider instance.
|
||||
// DELETE /api/v1/admin/payment/providers/:id
|
||||
func (h *PaymentHandler) DeleteProvider(c *gin.Context) {
|
||||
id, ok := parseIDParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.configService.DeleteProviderInstance(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.paymentService.RefreshProviders(c.Request.Context())
|
||||
response.Success(c, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// parseIDParam parses an int64 path parameter.
|
||||
// Returns the parsed ID and true on success; on failure it writes a BadRequest response and returns false.
|
||||
func parseIDParam(c *gin.Context, paramName string) (int64, bool) {
|
||||
id, err := strconv.ParseInt(c.Param(paramName), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid "+paramName)
|
||||
return 0, false
|
||||
}
|
||||
return id, true
|
||||
}
|
||||
|
||||
// --- Config ---
|
||||
|
||||
// GetConfig returns the payment configuration (admin view).
|
||||
// GET /api/v1/admin/payment/config
|
||||
func (h *PaymentHandler) GetConfig(c *gin.Context) {
|
||||
cfg, err := h.configService.GetPaymentConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the payment configuration.
|
||||
// PUT /api/v1/admin/payment/config
|
||||
func (h *PaymentHandler) UpdateConfig(c *gin.Context) {
|
||||
var req service.UpdatePaymentConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.configService.UpdatePaymentConfig(c.Request.Context(), req); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "updated"})
|
||||
}
|
||||
@@ -46,19 +46,23 @@ func scopesContainOpenID(scopes string) bool {
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
paymentConfigService *service.PaymentConfigService
|
||||
paymentService *service.PaymentService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
paymentConfigService: paymentConfigService,
|
||||
paymentService: paymentService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +85,15 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Load payment config
|
||||
var paymentCfg *service.PaymentConfig
|
||||
if h.paymentConfigService != nil {
|
||||
paymentCfg, _ = h.paymentConfigService.GetPaymentConfig(c.Request.Context())
|
||||
}
|
||||
if paymentCfg == nil {
|
||||
paymentCfg = &service.PaymentConfig{}
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
@@ -160,6 +173,24 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: settings.EnableCCHSigning,
|
||||
PaymentEnabled: paymentCfg.Enabled,
|
||||
PaymentMinAmount: paymentCfg.MinAmount,
|
||||
PaymentMaxAmount: paymentCfg.MaxAmount,
|
||||
PaymentDailyLimit: paymentCfg.DailyLimit,
|
||||
PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
|
||||
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
|
||||
PaymentEnabledTypes: paymentCfg.EnabledTypes,
|
||||
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
|
||||
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
|
||||
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
|
||||
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
|
||||
PaymentHelpImageURL: paymentCfg.HelpImageURL,
|
||||
PaymentHelpText: paymentCfg.HelpText,
|
||||
PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
|
||||
PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
|
||||
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -268,6 +299,28 @@ type UpdateSettingsRequest struct {
|
||||
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||
|
||||
// Payment configuration (integrated into settings, full replace)
|
||||
PaymentEnabled *bool `json:"payment_enabled"`
|
||||
PaymentMinAmount *float64 `json:"payment_min_amount"`
|
||||
PaymentMaxAmount *float64 `json:"payment_max_amount"`
|
||||
PaymentDailyLimit *float64 `json:"payment_daily_limit"`
|
||||
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
|
||||
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
|
||||
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
||||
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
|
||||
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
|
||||
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
|
||||
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
|
||||
PaymentHelpImageURL *string `json:"payment_help_image_url"`
|
||||
PaymentHelpText *string `json:"payment_help_text"`
|
||||
|
||||
// Cancel rate limit
|
||||
PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
|
||||
PaymentCancelRateLimitMax *int `json:"payment_cancel_rate_limit_max"`
|
||||
PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
|
||||
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
|
||||
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -822,6 +875,39 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update payment configuration (integrated into system settings).
|
||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
||||
paymentReq := service.UpdatePaymentConfigRequest{
|
||||
Enabled: req.PaymentEnabled,
|
||||
MinAmount: req.PaymentMinAmount,
|
||||
MaxAmount: req.PaymentMaxAmount,
|
||||
DailyLimit: req.PaymentDailyLimit,
|
||||
OrderTimeoutMin: req.PaymentOrderTimeoutMin,
|
||||
MaxPendingOrders: req.PaymentMaxPendingOrders,
|
||||
EnabledTypes: req.PaymentEnabledTypes,
|
||||
BalanceDisabled: req.PaymentBalanceDisabled,
|
||||
LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
|
||||
ProductNamePrefix: req.PaymentProductNamePrefix,
|
||||
ProductNameSuffix: req.PaymentProductNameSuffix,
|
||||
HelpImageURL: req.PaymentHelpImageURL,
|
||||
HelpText: req.PaymentHelpText,
|
||||
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
|
||||
CancelRateLimitMax: req.PaymentCancelRateLimitMax,
|
||||
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
|
||||
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
|
||||
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
|
||||
}
|
||||
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// Refresh in-memory provider registry so config changes take effect immediately
|
||||
if h.paymentService != nil {
|
||||
h.paymentService.RefreshProviders(c.Request.Context())
|
||||
}
|
||||
}
|
||||
|
||||
h.auditSettingsUpdate(c, previousSettings, settings, req)
|
||||
|
||||
// 重新获取设置返回
|
||||
@@ -838,6 +924,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Reload payment config for response
|
||||
var updatedPaymentCfg *service.PaymentConfig
|
||||
if h.paymentConfigService != nil {
|
||||
updatedPaymentCfg, _ = h.paymentConfigService.GetPaymentConfig(c.Request.Context())
|
||||
}
|
||||
if updatedPaymentCfg == nil {
|
||||
updatedPaymentCfg = &service.PaymentConfig{}
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
@@ -917,9 +1012,40 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||
PaymentEnabled: updatedPaymentCfg.Enabled,
|
||||
PaymentMinAmount: updatedPaymentCfg.MinAmount,
|
||||
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
|
||||
PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
|
||||
PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
|
||||
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
|
||||
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
|
||||
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
|
||||
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
|
||||
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
|
||||
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
|
||||
PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
|
||||
PaymentHelpText: updatedPaymentCfg.HelpText,
|
||||
PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
|
||||
PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
|
||||
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
|
||||
})
|
||||
}
|
||||
|
||||
// hasPaymentFields returns true if any payment-related field was explicitly provided.
|
||||
func hasPaymentFields(req UpdateSettingsRequest) bool {
|
||||
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
|
||||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
|
||||
req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
|
||||
req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
|
||||
req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
|
||||
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
|
||||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
|
||||
req.PaymentCancelRateLimitMax != nil || req.PaymentCancelRateLimitWindow != nil ||
|
||||
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
|
||||
}
|
||||
|
||||
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
|
||||
if before == nil || after == nil {
|
||||
return
|
||||
|
||||
@@ -121,6 +121,28 @@ type SystemSettings struct {
|
||||
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
|
||||
// Payment configuration
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
PaymentMinAmount float64 `json:"payment_min_amount"`
|
||||
PaymentMaxAmount float64 `json:"payment_max_amount"`
|
||||
PaymentDailyLimit float64 `json:"payment_daily_limit"`
|
||||
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
|
||||
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
|
||||
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
||||
PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
|
||||
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
|
||||
PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
|
||||
PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
|
||||
PaymentHelpImageURL string `json:"payment_help_image_url"`
|
||||
PaymentHelpText string `json:"payment_help_text"`
|
||||
|
||||
// Cancel rate limit
|
||||
PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
|
||||
PaymentCancelRateLimitMax int `json:"payment_cancel_rate_limit_max"`
|
||||
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
|
||||
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
|
||||
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -155,6 +177,7 @@ type PublicSettings struct {
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -31,22 +31,25 @@ type AdminHandlers struct {
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
Channel *admin.ChannelHandler
|
||||
Payment *admin.PaymentHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Announcement *AnnouncementHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
APIKey *APIKeyHandler
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Announcement *AnnouncementHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
Payment *PaymentHandler
|
||||
PaymentWebhook *PaymentWebhookHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
416
backend/internal/handler/payment_handler.go
Normal file
416
backend/internal/handler/payment_handler.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PaymentHandler handles user-facing payment requests.
|
||||
type PaymentHandler struct {
|
||||
channelService *service.ChannelService
|
||||
paymentService *service.PaymentService
|
||||
configService *service.PaymentConfigService
|
||||
}
|
||||
|
||||
// NewPaymentHandler creates a new PaymentHandler.
|
||||
func NewPaymentHandler(paymentService *service.PaymentService, configService *service.PaymentConfigService, channelService *service.ChannelService) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
channelService: channelService,
|
||||
paymentService: paymentService,
|
||||
configService: configService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPaymentConfig returns the payment system configuration.
|
||||
// GET /api/v1/payment/config
|
||||
func (h *PaymentHandler) GetPaymentConfig(c *gin.Context) {
|
||||
cfg, err := h.configService.GetPaymentConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// GetPlans returns subscription plans available for sale.
|
||||
// GET /api/v1/payment/plans
|
||||
func (h *PaymentHandler) GetPlans(c *gin.Context) {
|
||||
plans, err := h.configService.ListPlansForSale(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// Enrich plans with group platform for frontend color coding
|
||||
type planWithPlatform struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupPlatform string `json:"group_platform"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Price float64 `json:"price"`
|
||||
OriginalPrice *float64 `json:"original_price,omitempty"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityUnit string `json:"validity_unit"`
|
||||
Features string `json:"features"`
|
||||
ProductName string `json:"product_name"`
|
||||
ForSale bool `json:"for_sale"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
platformMap := h.configService.GetGroupPlatformMap(c.Request.Context(), plans)
|
||||
result := make([]planWithPlatform, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
result = append(result, planWithPlatform{
|
||||
ID: int64(p.ID), GroupID: p.GroupID, GroupPlatform: platformMap[p.GroupID],
|
||||
Name: p.Name, Description: p.Description, Price: p.Price, OriginalPrice: p.OriginalPrice,
|
||||
ValidityDays: p.ValidityDays, ValidityUnit: p.ValidityUnit, Features: p.Features,
|
||||
ProductName: p.ProductName, ForSale: p.ForSale, SortOrder: p.SortOrder,
|
||||
})
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetChannels returns enabled payment channels.
|
||||
// GET /api/v1/payment/channels
|
||||
func (h *PaymentHandler) GetChannels(c *gin.Context) {
|
||||
channels, _, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: 1, PageSize: 1000}, "active", "")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, channels)
|
||||
}
|
||||
|
||||
// GetCheckoutInfo returns all data the payment page needs in a single call:
|
||||
// payment methods with limits, subscription plans, and configuration.
|
||||
// GET /api/v1/payment/checkout-info
|
||||
func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Fetch limits (methods + global range)
|
||||
limitsResp, err := h.configService.GetAvailableMethodLimits(ctx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch payment config
|
||||
cfg, err := h.configService.GetPaymentConfig(ctx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch plans with group info
|
||||
plans, _ := h.configService.ListPlansForSale(ctx)
|
||||
groupInfo := h.configService.GetGroupInfoMap(ctx, plans)
|
||||
planList := make([]checkoutPlan, 0, len(plans))
|
||||
for _, p := range plans {
|
||||
gi := groupInfo[p.GroupID]
|
||||
planList = append(planList, checkoutPlan{
|
||||
ID: int64(p.ID), GroupID: p.GroupID,
|
||||
GroupPlatform: gi.Platform, GroupName: gi.Name,
|
||||
RateMultiplier: gi.RateMultiplier, DailyLimitUSD: gi.DailyLimitUSD,
|
||||
WeeklyLimitUSD: gi.WeeklyLimitUSD, MonthlyLimitUSD: gi.MonthlyLimitUSD,
|
||||
ModelScopes: gi.ModelScopes,
|
||||
Name: p.Name, Description: p.Description, Price: p.Price, OriginalPrice: p.OriginalPrice,
|
||||
ValidityDays: p.ValidityDays, ValidityUnit: p.ValidityUnit, Features: parseFeatures(p.Features),
|
||||
ProductName: p.ProductName,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, checkoutInfoResponse{
|
||||
Methods: limitsResp.Methods,
|
||||
GlobalMin: limitsResp.GlobalMin,
|
||||
GlobalMax: limitsResp.GlobalMax,
|
||||
Plans: planList,
|
||||
BalanceDisabled: cfg.BalanceDisabled,
|
||||
HelpText: cfg.HelpText,
|
||||
HelpImageURL: cfg.HelpImageURL,
|
||||
StripePublishableKey: cfg.StripePublishableKey,
|
||||
})
|
||||
}
|
||||
|
||||
type checkoutInfoResponse struct {
|
||||
Methods map[string]service.MethodLimits `json:"methods"`
|
||||
GlobalMin float64 `json:"global_min"`
|
||||
GlobalMax float64 `json:"global_max"`
|
||||
Plans []checkoutPlan `json:"plans"`
|
||||
BalanceDisabled bool `json:"balance_disabled"`
|
||||
HelpText string `json:"help_text"`
|
||||
HelpImageURL string `json:"help_image_url"`
|
||||
StripePublishableKey string `json:"stripe_publishable_key"`
|
||||
}
|
||||
|
||||
type checkoutPlan struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupPlatform string `json:"group_platform"`
|
||||
GroupName string `json:"group_name"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
ModelScopes []string `json:"supported_model_scopes"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Price float64 `json:"price"`
|
||||
OriginalPrice *float64 `json:"original_price,omitempty"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityUnit string `json:"validity_unit"`
|
||||
Features []string `json:"features"`
|
||||
ProductName string `json:"product_name"`
|
||||
}
|
||||
|
||||
// parseFeatures splits a newline-separated features string into a string slice.
|
||||
func parseFeatures(raw string) []string {
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
var out []string
|
||||
for _, line := range strings.Split(raw, "\n") {
|
||||
if s := strings.TrimSpace(line); s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
if out == nil {
|
||||
return []string{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetLimits returns per-payment-type limits derived from enabled provider instances.
|
||||
// GET /api/v1/payment/limits
|
||||
func (h *PaymentHandler) GetLimits(c *gin.Context) {
|
||||
resp, err := h.configService.GetAvailableMethodLimits(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// CreateOrderRequest is the request body for creating a payment order.
|
||||
type CreateOrderRequest struct {
|
||||
Amount float64 `json:"amount"`
|
||||
PaymentType string `json:"payment_type" binding:"required"`
|
||||
OrderType string `json:"order_type"`
|
||||
PlanID int64 `json:"plan_id"`
|
||||
}
|
||||
|
||||
// CreateOrder creates a new payment order.
|
||||
// POST /api/v1/payment/orders
|
||||
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateOrderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
|
||||
UserID: subject.UserID,
|
||||
Amount: req.Amount,
|
||||
PaymentType: req.PaymentType,
|
||||
ClientIP: c.ClientIP(),
|
||||
IsMobile: isMobile(c),
|
||||
SrcHost: c.Request.Host,
|
||||
SrcURL: c.Request.Referer(),
|
||||
OrderType: req.OrderType,
|
||||
PlanID: req.PlanID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetMyOrders returns the authenticated user's orders.
|
||||
// GET /api/v1/payment/orders/my
|
||||
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
orders, total, err := h.paymentService.GetUserOrders(c.Request.Context(), subject.UserID, service.OrderListParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Status: c.Query("status"),
|
||||
OrderType: c.Query("order_type"),
|
||||
PaymentType: c.Query("payment_type"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, orders, int64(total), page, pageSize)
|
||||
}
|
||||
|
||||
// GetOrder returns a single order for the authenticated user.
|
||||
// GET /api/v1/payment/orders/:id
|
||||
func (h *PaymentHandler) GetOrder(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid order ID")
|
||||
return
|
||||
}
|
||||
|
||||
order, err := h.paymentService.GetOrder(c.Request.Context(), orderID, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, order)
|
||||
}
|
||||
|
||||
// CancelOrder cancels a pending order for the authenticated user.
|
||||
// POST /api/v1/payment/orders/:id/cancel
|
||||
func (h *PaymentHandler) CancelOrder(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid order ID")
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := h.paymentService.CancelOrder(c.Request.Context(), orderID, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": msg})
|
||||
}
|
||||
|
||||
// RefundRequestBody is the request body for requesting a refund.
|
||||
type RefundRequestBody struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RequestRefund submits a refund request for a completed order.
|
||||
// POST /api/v1/payment/orders/:id/refund-request
|
||||
func (h *PaymentHandler) RequestRefund(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid order ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req RefundRequestBody
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.paymentService.RequestRefund(c.Request.Context(), orderID, subject.UserID, req.Reason); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "refund requested"})
|
||||
}
|
||||
|
||||
// VerifyOrderRequest is the request body for verifying a payment order.
|
||||
type VerifyOrderRequest struct {
|
||||
OutTradeNo string `json:"out_trade_no" binding:"required"`
|
||||
}
|
||||
|
||||
// VerifyOrder actively queries the upstream payment provider to check
|
||||
// if payment was made, and processes it if so.
|
||||
// POST /api/v1/payment/orders/verify
|
||||
func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
|
||||
subject, ok := requireAuth(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req VerifyOrderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
order, err := h.paymentService.VerifyOrderByOutTradeNo(c.Request.Context(), req.OutTradeNo, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, order)
|
||||
}
|
||||
|
||||
// PublicOrderResult is the limited order info returned by the public verify endpoint.
|
||||
// No user details are exposed — only payment status information.
|
||||
type PublicOrderResult struct {
|
||||
ID int64 `json:"id"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
Amount float64 `json:"amount"`
|
||||
PayAmount float64 `json:"pay_amount"`
|
||||
PaymentType string `json:"payment_type"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// VerifyOrderPublic verifies payment status without requiring authentication.
|
||||
// Returns limited order info (no user details) to prevent information leakage.
|
||||
// POST /api/v1/payment/public/orders/verify
|
||||
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
|
||||
var req VerifyOrderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, PublicOrderResult{
|
||||
ID: order.ID,
|
||||
OutTradeNo: order.OutTradeNo,
|
||||
Amount: order.Amount,
|
||||
PayAmount: order.PayAmount,
|
||||
PaymentType: order.PaymentType,
|
||||
Status: order.Status,
|
||||
})
|
||||
}
|
||||
|
||||
// requireAuth extracts the authenticated subject from the context.
|
||||
// Returns the subject and true on success; on failure it writes an Unauthorized response and returns false.
|
||||
func requireAuth(c *gin.Context) (middleware2.AuthSubject, bool) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return middleware2.AuthSubject{}, false
|
||||
}
|
||||
return subject, true
|
||||
}
|
||||
|
||||
// isMobile detects mobile user agents.
|
||||
func isMobile(c *gin.Context) bool {
|
||||
ua := strings.ToLower(c.GetHeader("User-Agent"))
|
||||
return strings.Contains(ua, "mobile") || strings.Contains(ua, "android") || strings.Contains(ua, "iphone")
|
||||
}
|
||||
152
backend/internal/handler/payment_webhook_handler.go
Normal file
152
backend/internal/handler/payment_webhook_handler.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PaymentWebhookHandler handles payment provider webhook callbacks.
|
||||
type PaymentWebhookHandler struct {
|
||||
paymentService *service.PaymentService
|
||||
registry *payment.Registry
|
||||
}
|
||||
|
||||
// maxWebhookBodySize is the maximum allowed webhook request body size (1 MB).
|
||||
const maxWebhookBodySize = 1 << 20
|
||||
|
||||
// webhookLogTruncateLen is the maximum length of raw body logged on verify failure.
|
||||
const webhookLogTruncateLen = 200
|
||||
|
||||
// NewPaymentWebhookHandler creates a new PaymentWebhookHandler.
|
||||
func NewPaymentWebhookHandler(paymentService *service.PaymentService, registry *payment.Registry) *PaymentWebhookHandler {
|
||||
return &PaymentWebhookHandler{
|
||||
paymentService: paymentService,
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// EasyPayNotify handles EasyPay payment notifications.
|
||||
// POST /api/v1/payment/webhook/easypay
|
||||
func (h *PaymentWebhookHandler) EasyPayNotify(c *gin.Context) {
|
||||
h.handleNotify(c, payment.TypeEasyPay)
|
||||
}
|
||||
|
||||
// AlipayNotify handles Alipay payment notifications.
|
||||
// POST /api/v1/payment/webhook/alipay
|
||||
func (h *PaymentWebhookHandler) AlipayNotify(c *gin.Context) {
|
||||
h.handleNotify(c, payment.TypeAlipay)
|
||||
}
|
||||
|
||||
// WxpayNotify handles WeChat Pay payment notifications.
|
||||
// POST /api/v1/payment/webhook/wxpay
|
||||
func (h *PaymentWebhookHandler) WxpayNotify(c *gin.Context) {
|
||||
h.handleNotify(c, payment.TypeWxpay)
|
||||
}
|
||||
|
||||
// StripeWebhook handles Stripe webhook events.
|
||||
// POST /api/v1/payment/webhook/stripe
|
||||
func (h *PaymentWebhookHandler) StripeWebhook(c *gin.Context) {
|
||||
h.handleNotify(c, payment.TypeStripe)
|
||||
}
|
||||
|
||||
// handleNotify is the shared logic for all provider webhook handlers.
|
||||
func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) {
|
||||
var rawBody string
|
||||
if c.Request.Method == http.MethodGet {
|
||||
// GET callbacks (e.g. EasyPay) pass params as URL query string
|
||||
rawBody = c.Request.URL.RawQuery
|
||||
} else {
|
||||
body, err := io.ReadAll(io.LimitReader(c.Request.Body, maxWebhookBodySize))
|
||||
if err != nil {
|
||||
slog.Error("[Payment Webhook] failed to read body", "provider", providerKey, "error", err)
|
||||
c.String(http.StatusBadRequest, "failed to read body")
|
||||
return
|
||||
}
|
||||
rawBody = string(body)
|
||||
}
|
||||
|
||||
// Extract out_trade_no to look up the order's specific provider instance.
|
||||
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
|
||||
outTradeNo := extractOutTradeNo(rawBody, providerKey)
|
||||
|
||||
provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
|
||||
if err != nil {
|
||||
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
|
||||
writeSuccessResponse(c, providerKey)
|
||||
return
|
||||
}
|
||||
|
||||
headers := make(map[string]string)
|
||||
for k := range c.Request.Header {
|
||||
headers[strings.ToLower(k)] = c.GetHeader(k)
|
||||
}
|
||||
|
||||
notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
|
||||
if err != nil {
|
||||
truncatedBody := rawBody
|
||||
if len(truncatedBody) > webhookLogTruncateLen {
|
||||
truncatedBody = truncatedBody[:webhookLogTruncateLen] + "...(truncated)"
|
||||
}
|
||||
slog.Error("[Payment Webhook] verify failed", "provider", providerKey, "error", err, "method", c.Request.Method, "bodyLen", len(rawBody))
|
||||
slog.Debug("[Payment Webhook] verify failed body", "provider", providerKey, "rawBody", truncatedBody)
|
||||
c.String(http.StatusBadRequest, "verify failed")
|
||||
return
|
||||
}
|
||||
|
||||
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
|
||||
if notification == nil {
|
||||
writeSuccessResponse(c, providerKey)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
|
||||
slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
|
||||
c.String(http.StatusInternalServerError, "handle failed")
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccessResponse(c, providerKey)
|
||||
}
|
||||
|
||||
// extractOutTradeNo parses the webhook body to find the out_trade_no.
|
||||
// This allows looking up the correct provider instance before verification.
|
||||
func extractOutTradeNo(rawBody, providerKey string) string {
|
||||
switch providerKey {
|
||||
case payment.TypeEasyPay:
|
||||
values, err := url.ParseQuery(rawBody)
|
||||
if err == nil {
|
||||
return values.Get("out_trade_no")
|
||||
}
|
||||
}
|
||||
// For other providers (Stripe, Alipay direct, WxPay direct), the registry
|
||||
// typically has only one instance, so no instance lookup is needed.
|
||||
return ""
|
||||
}
|
||||
|
||||
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
|
||||
type wxpaySuccessResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// writeSuccessResponse sends the provider-specific success response.
|
||||
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
|
||||
// Stripe expects an empty 200; others accept plain text "success".
|
||||
func writeSuccessResponse(c *gin.Context, providerKey string) {
|
||||
switch providerKey {
|
||||
case payment.TypeWxpay:
|
||||
c.JSON(http.StatusOK, wxpaySuccessResponse{Code: "SUCCESS", Message: "成功"})
|
||||
case payment.TypeStripe:
|
||||
c.String(http.StatusOK, "")
|
||||
default:
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
}
|
||||
99
backend/internal/handler/payment_webhook_handler_test.go
Normal file
99
backend/internal/handler/payment_webhook_handler_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteSuccessResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerKey string
|
||||
wantCode int
|
||||
wantContentType string
|
||||
wantBody string
|
||||
checkJSON bool
|
||||
wantJSONCode string
|
||||
wantJSONMessage string
|
||||
}{
|
||||
{
|
||||
name: "wxpay returns JSON with code SUCCESS",
|
||||
providerKey: "wxpay",
|
||||
wantCode: http.StatusOK,
|
||||
wantContentType: "application/json",
|
||||
checkJSON: true,
|
||||
wantJSONCode: "SUCCESS",
|
||||
wantJSONMessage: "成功",
|
||||
},
|
||||
{
|
||||
name: "stripe returns empty 200",
|
||||
providerKey: "stripe",
|
||||
wantCode: http.StatusOK,
|
||||
wantContentType: "text/plain",
|
||||
wantBody: "",
|
||||
},
|
||||
{
|
||||
name: "easypay returns plain text success",
|
||||
providerKey: "easypay",
|
||||
wantCode: http.StatusOK,
|
||||
wantContentType: "text/plain",
|
||||
wantBody: "success",
|
||||
},
|
||||
{
|
||||
name: "alipay returns plain text success",
|
||||
providerKey: "alipay",
|
||||
wantCode: http.StatusOK,
|
||||
wantContentType: "text/plain",
|
||||
wantBody: "success",
|
||||
},
|
||||
{
|
||||
name: "unknown provider returns plain text success",
|
||||
providerKey: "unknown_provider",
|
||||
wantCode: http.StatusOK,
|
||||
wantContentType: "text/plain",
|
||||
wantBody: "success",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
writeSuccessResponse(c, tt.providerKey)
|
||||
|
||||
assert.Equal(t, tt.wantCode, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), tt.wantContentType)
|
||||
|
||||
if tt.checkJSON {
|
||||
var resp wxpaySuccessResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, err, "response body should be valid JSON")
|
||||
assert.Equal(t, tt.wantJSONCode, resp.Code)
|
||||
assert.Equal(t, tt.wantJSONMessage, resp.Message)
|
||||
} else {
|
||||
assert.Equal(t, tt.wantBody, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookConstants(t *testing.T) {
|
||||
t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) {
|
||||
assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize))
|
||||
})
|
||||
|
||||
t.Run("webhookLogTruncateLen is 200", func(t *testing.T) {
|
||||
assert.Equal(t, 200, webhookLogTruncateLen)
|
||||
})
|
||||
}
|
||||
@@ -34,6 +34,7 @@ func ProvideAdminHandlers(
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
channelHandler *admin.ChannelHandler,
|
||||
paymentHandler *admin.PaymentHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -61,6 +62,7 @@ func ProvideAdminHandlers(
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
Channel: channelHandler,
|
||||
Payment: paymentHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,22 +90,26 @@ func ProvideHandlers(
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
paymentHandler *PaymentHandler,
|
||||
paymentWebhookHandler *PaymentWebhookHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
_ *service.IdempotencyCleanupService,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Announcement: announcementHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Announcement: announcementHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
Payment: paymentHandler,
|
||||
PaymentWebhook: paymentWebhookHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,6 +127,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewOpenAIGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
NewPaymentHandler,
|
||||
NewPaymentWebhookHandler,
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
@@ -148,6 +156,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
admin.NewChannelHandler,
|
||||
admin.NewPaymentHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
24
backend/internal/payment/amount.go
Normal file
24
backend/internal/payment/amount.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
const centsPerYuan = 100
|
||||
|
||||
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
|
||||
// Uses shopspring/decimal for precision.
|
||||
func YuanToFen(yuanStr string) (int64, error) {
|
||||
d, err := decimal.NewFromString(yuanStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid amount: %s", yuanStr)
|
||||
}
|
||||
return d.Mul(decimal.NewFromInt(centsPerYuan)).IntPart(), nil
|
||||
}
|
||||
|
||||
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
|
||||
func FenToYuan(fen int64) float64 {
|
||||
return decimal.NewFromInt(fen).Div(decimal.NewFromInt(centsPerYuan)).InexactFloat64()
|
||||
}
|
||||
128
backend/internal/payment/amount_test.go
Normal file
128
backend/internal/payment/amount_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build unit
|
||||
|
||||
package payment
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestYuanToFen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
// Normal values
|
||||
{name: "one yuan", input: "1.00", want: 100},
|
||||
{name: "ten yuan fifty fen", input: "10.50", want: 1050},
|
||||
{name: "one fen", input: "0.01", want: 1},
|
||||
{name: "large amount", input: "99999.99", want: 9999999},
|
||||
|
||||
// Edge: zero
|
||||
{name: "zero no decimal", input: "0", want: 0},
|
||||
{name: "zero with decimal", input: "0.00", want: 0},
|
||||
|
||||
// IEEE 754 precision edge case: 1.15 * 100 = 114.99999... in float64
|
||||
{name: "ieee754 precision 1.15", input: "1.15", want: 115},
|
||||
|
||||
// More precision edge cases
|
||||
{name: "ieee754 precision 0.1", input: "0.1", want: 10},
|
||||
{name: "ieee754 precision 0.2", input: "0.2", want: 20},
|
||||
{name: "ieee754 precision 33.33", input: "33.33", want: 3333},
|
||||
|
||||
// Large value
|
||||
{name: "hundred thousand", input: "100000.00", want: 10000000},
|
||||
|
||||
// Integer without decimal
|
||||
{name: "integer 5", input: "5", want: 500},
|
||||
{name: "integer 100", input: "100", want: 10000},
|
||||
|
||||
// Single decimal place
|
||||
{name: "single decimal 1.5", input: "1.5", want: 150},
|
||||
|
||||
// Negative values
|
||||
{name: "negative one yuan", input: "-1.00", want: -100},
|
||||
{name: "negative with fen", input: "-10.50", want: -1050},
|
||||
|
||||
// Invalid inputs
|
||||
{name: "empty string", input: "", wantErr: true},
|
||||
{name: "alphabetic", input: "abc", wantErr: true},
|
||||
{name: "double dot", input: "1.2.3", wantErr: true},
|
||||
{name: "spaces", input: " ", wantErr: true},
|
||||
{name: "special chars", input: "$10.00", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := YuanToFen(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("YuanToFen(%q) expected error, got %d", tt.input, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("YuanToFen(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("YuanToFen(%q) = %d, want %d", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFenToYuan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fen int64
|
||||
want float64
|
||||
}{
|
||||
{name: "one yuan", fen: 100, want: 1.0},
|
||||
{name: "ten yuan fifty fen", fen: 1050, want: 10.5},
|
||||
{name: "one fen", fen: 1, want: 0.01},
|
||||
{name: "zero", fen: 0, want: 0.0},
|
||||
{name: "large amount", fen: 9999999, want: 99999.99},
|
||||
{name: "negative", fen: -100, want: -1.0},
|
||||
{name: "negative with fen", fen: -1050, want: -10.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FenToYuan(tt.fen)
|
||||
if math.Abs(got-tt.want) > 1e-9 {
|
||||
t.Errorf("FenToYuan(%d) = %f, want %f", tt.fen, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestYuanToFenRoundTrip(t *testing.T) {
|
||||
// Verify that converting yuan->fen->yuan preserves the value.
|
||||
cases := []struct {
|
||||
yuan string
|
||||
fen int64
|
||||
}{
|
||||
{"0.01", 1},
|
||||
{"1.00", 100},
|
||||
{"10.50", 1050},
|
||||
{"99999.99", 9999999},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
fen, err := YuanToFen(tc.yuan)
|
||||
if err != nil {
|
||||
t.Fatalf("YuanToFen(%q) unexpected error: %v", tc.yuan, err)
|
||||
}
|
||||
if fen != tc.fen {
|
||||
t.Errorf("YuanToFen(%q) = %d, want %d", tc.yuan, fen, tc.fen)
|
||||
}
|
||||
yuan := FenToYuan(fen)
|
||||
// Parse expected yuan back for comparison
|
||||
expectedYuan := FenToYuan(tc.fen)
|
||||
if math.Abs(yuan-expectedYuan) > 1e-9 {
|
||||
t.Errorf("round-trip: FenToYuan(%d) = %f, want %f", fen, yuan, expectedYuan)
|
||||
}
|
||||
}
|
||||
}
|
||||
98
backend/internal/payment/crypto.go
Normal file
98
backend/internal/payment/crypto.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
|
||||
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
|
||||
// matching the Node.js crypto.ts format for cross-compatibility.
|
||||
func Encrypt(plaintext string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create GCM: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize()) // 12 bytes for GCM
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Seal appends the ciphertext + auth tag
|
||||
sealed := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Split sealed into ciphertext and auth tag (last 16 bytes)
|
||||
tagSize := gcm.Overhead()
|
||||
ciphertext := sealed[:len(sealed)-tagSize]
|
||||
authTag := sealed[len(sealed)-tagSize:]
|
||||
|
||||
// Format: iv:authTag:ciphertext (all base64)
|
||||
return fmt.Sprintf("%s:%s:%s",
|
||||
base64.StdEncoding.EncodeToString(nonce),
|
||||
base64.StdEncoding.EncodeToString(authTag),
|
||||
base64.StdEncoding.EncodeToString(ciphertext),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a ciphertext string produced by Encrypt.
|
||||
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
|
||||
func Decrypt(ciphertext string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
|
||||
parts := strings.SplitN(ciphertext, ":", 3)
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid ciphertext format: expected iv:authTag:ciphertext")
|
||||
}
|
||||
|
||||
nonce, err := base64.StdEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode IV: %w", err)
|
||||
}
|
||||
|
||||
authTag, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode auth tag: %w", err)
|
||||
}
|
||||
|
||||
encrypted, err := base64.StdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode ciphertext: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Reconstruct the sealed data: ciphertext + authTag
|
||||
sealed := append(encrypted, authTag...)
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, sealed, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
183
backend/internal/payment/crypto_test.go
Normal file
183
backend/internal/payment/crypto_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeKey(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatalf("generate random key: %v", err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func TestEncryptDecryptRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
plaintexts := []string{
|
||||
"hello world",
|
||||
"short",
|
||||
"a longer string with special chars: !@#$%^&*()",
|
||||
`{"key":"value","num":42}`,
|
||||
"你好世界 unicode test 🎉",
|
||||
strings.Repeat("x", 10000),
|
||||
}
|
||||
|
||||
for _, pt := range plaintexts {
|
||||
encrypted, err := Encrypt(pt, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt(%q) error: %v", pt[:min(len(pt), 30)], err)
|
||||
}
|
||||
decrypted, err := Decrypt(encrypted, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt error for plaintext %q: %v", pt[:min(len(pt), 30)], err)
|
||||
}
|
||||
if decrypted != pt {
|
||||
t.Fatalf("round-trip failed: got %q, want %q", decrypted[:min(len(decrypted), 30)], pt[:min(len(pt), 30)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
ct1, err := Encrypt("same plaintext", key)
|
||||
if err != nil {
|
||||
t.Fatalf("first Encrypt error: %v", err)
|
||||
}
|
||||
ct2, err := Encrypt("same plaintext", key)
|
||||
if err != nil {
|
||||
t.Fatalf("second Encrypt error: %v", err)
|
||||
}
|
||||
if ct1 == ct2 {
|
||||
t.Fatal("two encryptions of the same plaintext should produce different ciphertexts (random nonce)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongKeyFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
key1 := makeKey(t)
|
||||
key2 := makeKey(t)
|
||||
|
||||
encrypted, err := Encrypt("secret data", key1)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt error: %v", err)
|
||||
}
|
||||
|
||||
_, err = Decrypt(encrypted, key2)
|
||||
if err == nil {
|
||||
t.Fatal("Decrypt with wrong key should fail, but got nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptRejectsInvalidKeyLength(t *testing.T) {
|
||||
t.Parallel()
|
||||
badKeys := [][]byte{
|
||||
nil,
|
||||
make([]byte, 0),
|
||||
make([]byte, 16),
|
||||
make([]byte, 31),
|
||||
make([]byte, 33),
|
||||
make([]byte, 64),
|
||||
}
|
||||
for _, key := range badKeys {
|
||||
_, err := Encrypt("test", key)
|
||||
if err == nil {
|
||||
t.Fatalf("Encrypt should reject key of length %d", len(key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptRejectsInvalidKeyLength(t *testing.T) {
|
||||
t.Parallel()
|
||||
badKeys := [][]byte{
|
||||
nil,
|
||||
make([]byte, 16),
|
||||
make([]byte, 33),
|
||||
}
|
||||
for _, key := range badKeys {
|
||||
_, err := Decrypt("dummydata:dummydata:dummydata", key)
|
||||
if err == nil {
|
||||
t.Fatalf("Decrypt should reject key of length %d", len(key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptEmptyPlaintext(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
encrypted, err := Encrypt("", key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt empty plaintext error: %v", err)
|
||||
}
|
||||
decrypted, err := Decrypt(encrypted, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt empty plaintext error: %v", err)
|
||||
}
|
||||
if decrypted != "" {
|
||||
t.Fatalf("expected empty string, got %q", decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptUnicodeJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
jsonContent := `{"name":"测试用户","email":"test@example.com","balance":100.50}`
|
||||
encrypted, err := Encrypt(jsonContent, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt JSON error: %v", err)
|
||||
}
|
||||
decrypted, err := Decrypt(encrypted, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt JSON error: %v", err)
|
||||
}
|
||||
if decrypted != jsonContent {
|
||||
t.Fatalf("JSON round-trip failed: got %q, want %q", decrypted, jsonContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
invalidInputs := []string{
|
||||
"",
|
||||
"nodelimiter",
|
||||
"only:two",
|
||||
"invalid:base64:!!!",
|
||||
}
|
||||
for _, input := range invalidInputs {
|
||||
_, err := Decrypt(input, key)
|
||||
if err == nil {
|
||||
t.Fatalf("Decrypt(%q) should fail but got nil error", input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCiphertextFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := makeKey(t)
|
||||
|
||||
encrypted, err := Encrypt("test", key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt error: %v", err)
|
||||
}
|
||||
|
||||
parts := strings.SplitN(encrypted, ":", 3)
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("ciphertext should have format iv:authTag:ciphertext, got %d parts", len(parts))
|
||||
}
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
t.Fatalf("ciphertext part %d is empty", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
backend/internal/payment/fee.go
Normal file
19
backend/internal/payment/fee.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
// CalculatePayAmount computes the total pay amount given a recharge amount and
|
||||
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
|
||||
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
|
||||
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
|
||||
func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
|
||||
amount := decimal.NewFromFloat(rechargeAmount)
|
||||
if feeRate <= 0 {
|
||||
return amount.StringFixed(2)
|
||||
}
|
||||
rate := decimal.NewFromFloat(feeRate)
|
||||
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(2)
|
||||
return amount.Add(fee).StringFixed(2)
|
||||
}
|
||||
111
backend/internal/payment/fee_test.go
Normal file
111
backend/internal/payment/fee_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCalculatePayAmount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
amount float64
|
||||
feeRate float64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero fee rate returns same amount",
|
||||
amount: 100.00,
|
||||
feeRate: 0,
|
||||
expected: "100.00",
|
||||
},
|
||||
{
|
||||
name: "negative fee rate returns same amount",
|
||||
amount: 50.00,
|
||||
feeRate: -5,
|
||||
expected: "50.00",
|
||||
},
|
||||
{
|
||||
name: "1 percent fee rate",
|
||||
amount: 100.00,
|
||||
feeRate: 1,
|
||||
expected: "101.00",
|
||||
},
|
||||
{
|
||||
name: "5 percent fee on 200",
|
||||
amount: 200.00,
|
||||
feeRate: 5,
|
||||
expected: "210.00",
|
||||
},
|
||||
{
|
||||
name: "fee rounds UP to 2 decimal places",
|
||||
amount: 100.00,
|
||||
feeRate: 3,
|
||||
expected: "103.00",
|
||||
},
|
||||
{
|
||||
name: "fee rounds UP small remainder",
|
||||
amount: 10.00,
|
||||
feeRate: 3.33,
|
||||
expected: "10.34", // 10 * 3.33 / 100 = 0.333 -> round up -> 0.34
|
||||
},
|
||||
{
|
||||
name: "very small amount",
|
||||
amount: 0.01,
|
||||
feeRate: 1,
|
||||
expected: "0.02", // 0.01 * 1/100 = 0.0001 -> round up -> 0.01 -> total 0.02
|
||||
},
|
||||
{
|
||||
name: "large amount",
|
||||
amount: 99999.99,
|
||||
feeRate: 10,
|
||||
expected: "109999.99", // 99999.99 * 10/100 = 9999.999 -> round up -> 10000.00 -> total 109999.99
|
||||
},
|
||||
{
|
||||
name: "100 percent fee rate doubles amount",
|
||||
amount: 50.00,
|
||||
feeRate: 100,
|
||||
expected: "100.00",
|
||||
},
|
||||
{
|
||||
name: "precision 0.01 fee difference",
|
||||
amount: 100.00,
|
||||
feeRate: 1.01,
|
||||
expected: "101.01", // 100 * 1.01/100 = 1.01
|
||||
},
|
||||
{
|
||||
name: "precision 0.02 fee",
|
||||
amount: 100.00,
|
||||
feeRate: 1.02,
|
||||
expected: "101.02",
|
||||
},
|
||||
{
|
||||
name: "zero amount with positive fee",
|
||||
amount: 0,
|
||||
feeRate: 5,
|
||||
expected: "0.00",
|
||||
},
|
||||
{
|
||||
name: "fractional amount no fee",
|
||||
amount: 19.99,
|
||||
feeRate: 0,
|
||||
expected: "19.99",
|
||||
},
|
||||
{
|
||||
name: "fractional fee that causes rounding up",
|
||||
amount: 33.33,
|
||||
feeRate: 7.77,
|
||||
expected: "35.92", // 33.33 * 7.77 / 100 = 2.589741 -> round up -> 2.59 -> total 35.92
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := CalculatePayAmount(tt.amount, tt.feeRate)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("CalculatePayAmount(%v, %v) = %q, want %q", tt.amount, tt.feeRate, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
328
backend/internal/payment/load_balancer.go
Normal file
328
backend/internal/payment/load_balancer.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
)
|
||||
|
||||
// Strategy represents a load balancing strategy for provider instance selection.
|
||||
type Strategy string
|
||||
|
||||
const (
|
||||
StrategyRoundRobin Strategy = "round-robin"
|
||||
StrategyLeastAmount Strategy = "least-amount"
|
||||
)
|
||||
|
||||
// ChannelLimits holds limits for a single payment channel within a provider instance.
|
||||
type ChannelLimits struct {
|
||||
DailyLimit float64 `json:"dailyLimit,omitempty"`
|
||||
SingleMin float64 `json:"singleMin,omitempty"`
|
||||
SingleMax float64 `json:"singleMax,omitempty"`
|
||||
}
|
||||
|
||||
// InstanceLimits holds per-channel limits for a provider instance (JSON).
|
||||
type InstanceLimits map[string]ChannelLimits
|
||||
|
||||
// LoadBalancer selects a provider instance for a given payment type.
|
||||
type LoadBalancer interface {
|
||||
GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error)
|
||||
SelectInstance(ctx context.Context, providerKey string, paymentType PaymentType, strategy Strategy, orderAmount float64) (*InstanceSelection, error)
|
||||
}
|
||||
|
||||
// DefaultLoadBalancer implements LoadBalancer using database queries.
|
||||
type DefaultLoadBalancer struct {
|
||||
db *dbent.Client
|
||||
encryptionKey []byte
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
// NewDefaultLoadBalancer creates a new load balancer.
|
||||
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
|
||||
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
|
||||
}
|
||||
|
||||
// instanceCandidate pairs an instance with its pre-fetched daily usage.
|
||||
type instanceCandidate struct {
|
||||
inst *dbent.PaymentProviderInstance
|
||||
dailyUsed float64 // includes PENDING orders
|
||||
}
|
||||
|
||||
// SelectInstance picks an enabled instance for the given provider key and payment type.
|
||||
//
|
||||
// Flow:
|
||||
// 1. Query all enabled instances for providerKey, filter by supported paymentType
|
||||
// 2. Batch-query daily usage (PENDING + PAID + COMPLETED + RECHARGING) for all candidates
|
||||
// 3. Filter out instances where: single-min/max violated OR daily remaining < orderAmount
|
||||
// 4. Pick from survivors using the configured strategy (round-robin / least-amount)
|
||||
// 5. If all filtered out, fall back to full list (let the provider itself reject)
|
||||
func (lb *DefaultLoadBalancer) SelectInstance(
|
||||
ctx context.Context,
|
||||
providerKey string,
|
||||
paymentType PaymentType,
|
||||
strategy Strategy,
|
||||
orderAmount float64,
|
||||
) (*InstanceSelection, error) {
|
||||
// Step 1: query enabled instances matching payment type.
|
||||
instances, err := lb.queryEnabledInstances(ctx, providerKey, paymentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Step 2: batch-fetch daily usage for all candidates.
|
||||
candidates := lb.attachDailyUsage(ctx, instances)
|
||||
|
||||
// Step 3: filter by limits.
|
||||
available := filterByLimits(candidates, paymentType, orderAmount)
|
||||
if len(available) == 0 {
|
||||
slog.Warn("all instances exceeded limits, using full candidate list",
|
||||
"provider", providerKey, "payment_type", paymentType,
|
||||
"order_amount", orderAmount, "count", len(candidates))
|
||||
available = candidates
|
||||
}
|
||||
|
||||
// Step 4: pick by strategy.
|
||||
selected := lb.pickByStrategy(available, strategy)
|
||||
return lb.buildSelection(selected.inst)
|
||||
}
|
||||
|
||||
// queryEnabledInstances returns enabled instances for providerKey that support paymentType.
|
||||
func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
||||
ctx context.Context,
|
||||
providerKey string,
|
||||
paymentType PaymentType,
|
||||
) ([]*dbent.PaymentProviderInstance, error) {
|
||||
instances, err := lb.db.PaymentProviderInstance.Query().
|
||||
Where(
|
||||
paymentproviderinstance.ProviderKey(providerKey),
|
||||
paymentproviderinstance.Enabled(true),
|
||||
).
|
||||
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query provider instances: %w", err)
|
||||
}
|
||||
|
||||
var matched []*dbent.PaymentProviderInstance
|
||||
for _, inst := range instances {
|
||||
if paymentType == providerKey || InstanceSupportsType(inst.SupportedTypes, paymentType) {
|
||||
matched = append(matched, inst)
|
||||
}
|
||||
}
|
||||
if len(matched) == 0 {
|
||||
return nil, fmt.Errorf("no enabled instance for provider %s type %s", providerKey, paymentType)
|
||||
}
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// attachDailyUsage queries daily usage for each instance in a single pass.
|
||||
// Usage includes PENDING orders to avoid over-committing capacity.
|
||||
func (lb *DefaultLoadBalancer) attachDailyUsage(
|
||||
ctx context.Context,
|
||||
instances []*dbent.PaymentProviderInstance,
|
||||
) []instanceCandidate {
|
||||
todayStart := startOfDay(time.Now())
|
||||
|
||||
// Collect instance IDs.
|
||||
ids := make([]string, len(instances))
|
||||
for i, inst := range instances {
|
||||
ids[i] = fmt.Sprintf("%d", inst.ID)
|
||||
}
|
||||
|
||||
// Batch query: sum pay_amount grouped by provider_instance_id.
|
||||
type row struct {
|
||||
InstanceID string `json:"provider_instance_id"`
|
||||
Sum float64 `json:"sum"`
|
||||
}
|
||||
var rows []row
|
||||
err := lb.db.PaymentOrder.Query().
|
||||
Where(
|
||||
paymentorder.ProviderInstanceIDIn(ids...),
|
||||
paymentorder.StatusIn(
|
||||
OrderStatusPending, OrderStatusPaid,
|
||||
OrderStatusCompleted, OrderStatusRecharging,
|
||||
),
|
||||
paymentorder.CreatedAtGTE(todayStart),
|
||||
).
|
||||
GroupBy(paymentorder.FieldProviderInstanceID).
|
||||
Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
|
||||
Scan(ctx, &rows)
|
||||
if err != nil {
|
||||
slog.Warn("batch daily usage query failed, treating all as zero", "error", err)
|
||||
}
|
||||
|
||||
usageMap := make(map[string]float64, len(rows))
|
||||
for _, r := range rows {
|
||||
usageMap[r.InstanceID] = r.Sum
|
||||
}
|
||||
|
||||
candidates := make([]instanceCandidate, len(instances))
|
||||
for i, inst := range instances {
|
||||
candidates[i] = instanceCandidate{
|
||||
inst: inst,
|
||||
dailyUsed: usageMap[fmt.Sprintf("%d", inst.ID)],
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
// filterByLimits removes instances that cannot accommodate the order:
|
||||
// - orderAmount outside single-transaction [min, max]
|
||||
// - daily remaining capacity (limit - used) < orderAmount
|
||||
func filterByLimits(candidates []instanceCandidate, paymentType PaymentType, orderAmount float64) []instanceCandidate {
|
||||
var result []instanceCandidate
|
||||
for _, c := range candidates {
|
||||
cl := getInstanceChannelLimits(c.inst, paymentType)
|
||||
|
||||
if cl.SingleMin > 0 && orderAmount < cl.SingleMin {
|
||||
slog.Info("order below instance single min, skipping",
|
||||
"instance_id", c.inst.ID, "order", orderAmount, "min", cl.SingleMin)
|
||||
continue
|
||||
}
|
||||
if cl.SingleMax > 0 && orderAmount > cl.SingleMax {
|
||||
slog.Info("order above instance single max, skipping",
|
||||
"instance_id", c.inst.ID, "order", orderAmount, "max", cl.SingleMax)
|
||||
continue
|
||||
}
|
||||
if cl.DailyLimit > 0 && c.dailyUsed+orderAmount > cl.DailyLimit {
|
||||
slog.Info("instance daily remaining insufficient, skipping",
|
||||
"instance_id", c.inst.ID, "used", c.dailyUsed,
|
||||
"order", orderAmount, "limit", cl.DailyLimit)
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getInstanceChannelLimits returns the channel limits for a specific payment type.
|
||||
func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType PaymentType) ChannelLimits {
|
||||
if inst.Limits == "" {
|
||||
return ChannelLimits{}
|
||||
}
|
||||
var limits InstanceLimits
|
||||
if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
|
||||
return ChannelLimits{}
|
||||
}
|
||||
// For Stripe, limits are stored under the provider key "stripe".
|
||||
lookupKey := paymentType
|
||||
if inst.ProviderKey == "stripe" {
|
||||
lookupKey = "stripe"
|
||||
}
|
||||
if cl, ok := limits[lookupKey]; ok {
|
||||
return cl
|
||||
}
|
||||
return ChannelLimits{}
|
||||
}
|
||||
|
||||
// pickByStrategy selects one instance from the available candidates.
|
||||
func (lb *DefaultLoadBalancer) pickByStrategy(candidates []instanceCandidate, strategy Strategy) instanceCandidate {
|
||||
if strategy == StrategyLeastAmount && len(candidates) > 1 {
|
||||
return pickLeastAmount(candidates)
|
||||
}
|
||||
// Default: round-robin.
|
||||
idx := lb.counter.Add(1) % uint64(len(candidates))
|
||||
return candidates[idx]
|
||||
}
|
||||
|
||||
// pickLeastAmount selects the instance with the lowest daily usage.
|
||||
// No extra DB queries — usage was pre-fetched in attachDailyUsage.
|
||||
func pickLeastAmount(candidates []instanceCandidate) instanceCandidate {
|
||||
best := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.dailyUsed < best.dailyUsed {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderInstance) (*InstanceSelection, error) {
|
||||
config, err := lb.decryptConfig(selected.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
|
||||
}
|
||||
|
||||
if selected.PaymentMode != "" {
|
||||
config["paymentMode"] = selected.PaymentMode
|
||||
}
|
||||
|
||||
return &InstanceSelection{
|
||||
InstanceID: fmt.Sprintf("%d", selected.ID),
|
||||
Config: config,
|
||||
SupportedTypes: selected.SupportedTypes,
|
||||
PaymentMode: selected.PaymentMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
|
||||
plaintext, err := Decrypt(encrypted, lb.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var config map[string]string
|
||||
if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal config: %w", err)
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
|
||||
func (lb *DefaultLoadBalancer) GetInstanceDailyAmount(ctx context.Context, instanceID string) (float64, error) {
|
||||
todayStart := startOfDay(time.Now())
|
||||
|
||||
var result []struct {
|
||||
Sum float64 `json:"sum"`
|
||||
}
|
||||
err := lb.db.PaymentOrder.Query().
|
||||
Where(
|
||||
paymentorder.ProviderInstanceID(instanceID),
|
||||
paymentorder.StatusIn(OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging),
|
||||
paymentorder.PaidAtGTE(todayStart),
|
||||
).
|
||||
Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
|
||||
Scan(ctx, &result)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query daily amount: %w", err)
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result[0].Sum, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
||||
// InstanceSupportsType checks if the given supported types string includes the target type.
|
||||
// An empty supportedTypes string means all types are supported.
|
||||
func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
|
||||
if supportedTypes == "" {
|
||||
return true
|
||||
}
|
||||
for _, t := range strings.Split(supportedTypes, ",") {
|
||||
if strings.TrimSpace(t) == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
|
||||
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
|
||||
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get instance %d: %w", instanceID, err)
|
||||
}
|
||||
return lb.decryptConfig(inst.Config)
|
||||
}
|
||||
474
backend/internal/payment/load_balancer_test.go
Normal file
474
backend/internal/payment/load_balancer_test.go
Normal file
@@ -0,0 +1,474 @@
|
||||
//go:build unit
|
||||
|
||||
package payment
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
)
|
||||
|
||||
func TestInstanceSupportsType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
supportedTypes string
|
||||
target PaymentType
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match single type",
|
||||
supportedTypes: "alipay",
|
||||
target: "alipay",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no match single type",
|
||||
supportedTypes: "wxpay",
|
||||
target: "alipay",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "match in comma-separated list",
|
||||
supportedTypes: "alipay,wxpay,stripe",
|
||||
target: "wxpay",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "first in comma-separated list",
|
||||
supportedTypes: "alipay,wxpay",
|
||||
target: "alipay",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "last in comma-separated list",
|
||||
supportedTypes: "alipay,wxpay,stripe",
|
||||
target: "stripe",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no match in comma-separated list",
|
||||
supportedTypes: "alipay,wxpay",
|
||||
target: "stripe",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty target",
|
||||
supportedTypes: "alipay,wxpay",
|
||||
target: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "types with spaces are trimmed",
|
||||
supportedTypes: " alipay , wxpay ",
|
||||
target: "alipay",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "partial match should not succeed",
|
||||
supportedTypes: "alipay_direct",
|
||||
target: "alipay",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty supported types means all supported",
|
||||
supportedTypes: "",
|
||||
target: "alipay",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := InstanceSupportsType(tt.supportedTypes, tt.target)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("InstanceSupportsType(%q, %q) = %v, want %v", tt.supportedTypes, tt.target, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper to build test PaymentProviderInstance values
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func testInstance(id int64, providerKey, limits string) *dbent.PaymentProviderInstance {
|
||||
return &dbent.PaymentProviderInstance{
|
||||
ID: id,
|
||||
ProviderKey: providerKey,
|
||||
Limits: limits,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// makeLimitsJSON builds a limits JSON string for a single payment type.
|
||||
func makeLimitsJSON(paymentType string, cl ChannelLimits) string {
|
||||
m := map[string]ChannelLimits{paymentType: cl}
|
||||
b, _ := json.Marshal(m)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// filterByLimits
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFilterByLimits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
candidates []instanceCandidate
|
||||
paymentType PaymentType
|
||||
orderAmount float64
|
||||
wantIDs []int64 // expected surviving instance IDs
|
||||
}{
|
||||
{
|
||||
name: "order below SingleMin is filtered out",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 5,
|
||||
wantIDs: nil,
|
||||
},
|
||||
{
|
||||
name: "order at exact SingleMin boundary passes",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 10,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "order above SingleMax is filtered out",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 150,
|
||||
wantIDs: nil,
|
||||
},
|
||||
{
|
||||
name: "order at exact SingleMax boundary passes",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 100,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "daily used + orderAmount exceeding dailyLimit is filtered out",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 30,
|
||||
wantIDs: nil, // 480+30=510 > 500
|
||||
},
|
||||
{
|
||||
name: "daily used + orderAmount equal to dailyLimit passes (strict greater-than)",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 20,
|
||||
wantIDs: []int64{1}, // 480+20=500, 500 > 500 is false → passes
|
||||
},
|
||||
{
|
||||
name: "daily used + orderAmount below dailyLimit passes",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 400},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 50,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "no limits configured passes through",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", ""), dailyUsed: 99999},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 100,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "multiple candidates with partial filtering",
|
||||
candidates: []instanceCandidate{
|
||||
// singleMax=50, order=80 → filtered out
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 50})), dailyUsed: 0},
|
||||
// no limits → passes
|
||||
{inst: testInstance(2, "easypay", ""), dailyUsed: 0},
|
||||
// singleMin=100, order=80 → filtered out
|
||||
{inst: testInstance(3, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 100})), dailyUsed: 0},
|
||||
// daily limit ok → passes (500+80=580 < 1000)
|
||||
{inst: testInstance(4, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 1000})), dailyUsed: 500},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 80,
|
||||
wantIDs: []int64{2, 4},
|
||||
},
|
||||
{
|
||||
name: "zero SingleMin and SingleMax means no single-transaction limit",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 0, SingleMax: 0, DailyLimit: 0})), dailyUsed: 0},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 99999,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "all limits combined - order passes all checks",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 50,
|
||||
wantIDs: []int64{1},
|
||||
},
|
||||
{
|
||||
name: "all limits combined - order fails SingleMin",
|
||||
candidates: []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
|
||||
},
|
||||
paymentType: "alipay",
|
||||
orderAmount: 5,
|
||||
wantIDs: nil,
|
||||
},
|
||||
{
|
||||
name: "empty candidates returns empty",
|
||||
candidates: nil,
|
||||
paymentType: "alipay",
|
||||
orderAmount: 10,
|
||||
wantIDs: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := filterByLimits(tt.candidates, tt.paymentType, tt.orderAmount)
|
||||
gotIDs := make([]int64, len(got))
|
||||
for i, c := range got {
|
||||
gotIDs[i] = c.inst.ID
|
||||
}
|
||||
if !int64SliceEqual(gotIDs, tt.wantIDs) {
|
||||
t.Fatalf("filterByLimits() returned IDs %v, want %v", gotIDs, tt.wantIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// pickLeastAmount
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestPickLeastAmount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("picks candidate with lowest dailyUsed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
candidates := []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", ""), dailyUsed: 300},
|
||||
{inst: testInstance(2, "easypay", ""), dailyUsed: 100},
|
||||
{inst: testInstance(3, "easypay", ""), dailyUsed: 200},
|
||||
}
|
||||
got := pickLeastAmount(candidates)
|
||||
if got.inst.ID != 2 {
|
||||
t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with equal dailyUsed picks the first one", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
candidates := []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", ""), dailyUsed: 100},
|
||||
{inst: testInstance(2, "easypay", ""), dailyUsed: 100},
|
||||
{inst: testInstance(3, "easypay", ""), dailyUsed: 200},
|
||||
}
|
||||
got := pickLeastAmount(candidates)
|
||||
if got.inst.ID != 1 {
|
||||
t.Fatalf("pickLeastAmount() picked instance %d, want 1 (first with lowest)", got.inst.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single candidate returns that candidate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
candidates := []instanceCandidate{
|
||||
{inst: testInstance(42, "easypay", ""), dailyUsed: 999},
|
||||
}
|
||||
got := pickLeastAmount(candidates)
|
||||
if got.inst.ID != 42 {
|
||||
t.Fatalf("pickLeastAmount() picked instance %d, want 42", got.inst.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero usage among non-zero picks zero", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
candidates := []instanceCandidate{
|
||||
{inst: testInstance(1, "easypay", ""), dailyUsed: 500},
|
||||
{inst: testInstance(2, "easypay", ""), dailyUsed: 0},
|
||||
{inst: testInstance(3, "easypay", ""), dailyUsed: 300},
|
||||
}
|
||||
got := pickLeastAmount(candidates)
|
||||
if got.inst.ID != 2 {
|
||||
t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getInstanceChannelLimits
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetInstanceChannelLimits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
inst *dbent.PaymentProviderInstance
|
||||
paymentType PaymentType
|
||||
want ChannelLimits
|
||||
}{
|
||||
{
|
||||
name: "empty limits string returns zero ChannelLimits",
|
||||
inst: testInstance(1, "easypay", ""),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON returns zero ChannelLimits",
|
||||
inst: testInstance(1, "easypay", "not-json{"),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{},
|
||||
},
|
||||
{
|
||||
name: "valid JSON with matching payment type",
|
||||
inst: testInstance(1, "easypay",
|
||||
`{"alipay":{"singleMin":5,"singleMax":200,"dailyLimit":1000}}`),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{SingleMin: 5, SingleMax: 200, DailyLimit: 1000},
|
||||
},
|
||||
{
|
||||
name: "payment type not in limits returns zero ChannelLimits",
|
||||
inst: testInstance(1, "easypay",
|
||||
`{"alipay":{"singleMin":5,"singleMax":200}}`),
|
||||
paymentType: "wxpay",
|
||||
want: ChannelLimits{},
|
||||
},
|
||||
{
|
||||
name: "stripe provider uses stripe lookup key regardless of payment type",
|
||||
inst: testInstance(1, "stripe",
|
||||
`{"stripe":{"singleMin":10,"singleMax":500,"dailyLimit":5000}}`),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{SingleMin: 10, SingleMax: 500, DailyLimit: 5000},
|
||||
},
|
||||
{
|
||||
name: "stripe provider ignores payment type key even if present",
|
||||
inst: testInstance(1, "stripe",
|
||||
`{"stripe":{"singleMin":10,"singleMax":500},"alipay":{"singleMin":1,"singleMax":100}}`),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{SingleMin: 10, SingleMax: 500},
|
||||
},
|
||||
{
|
||||
name: "non-stripe provider uses payment type as lookup key",
|
||||
inst: testInstance(1, "easypay",
|
||||
`{"alipay":{"singleMin":5},"wxpay":{"singleMin":10}}`),
|
||||
paymentType: "wxpay",
|
||||
want: ChannelLimits{SingleMin: 10},
|
||||
},
|
||||
{
|
||||
name: "valid JSON with partial limits (only dailyLimit)",
|
||||
inst: testInstance(1, "easypay",
|
||||
`{"alipay":{"dailyLimit":800}}`),
|
||||
paymentType: "alipay",
|
||||
want: ChannelLimits{DailyLimit: 800},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := getInstanceChannelLimits(tt.inst, tt.paymentType)
|
||||
if got != tt.want {
|
||||
t.Fatalf("getInstanceChannelLimits() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// startOfDay
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestStartOfDay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in time.Time
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "midday returns midnight of same day",
|
||||
in: time.Date(2025, 6, 15, 14, 30, 45, 123456789, time.UTC),
|
||||
want: time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "midnight returns same time",
|
||||
in: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
want: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "last second of day returns midnight of same day",
|
||||
in: time.Date(2025, 12, 31, 23, 59, 59, 999999999, time.UTC),
|
||||
want: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "preserves timezone location",
|
||||
in: time.Date(2025, 3, 10, 15, 0, 0, 0, time.FixedZone("CST", 8*3600)),
|
||||
want: time.Date(2025, 3, 10, 0, 0, 0, 0, time.FixedZone("CST", 8*3600)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := startOfDay(tt.in)
|
||||
if !got.Equal(tt.want) {
|
||||
t.Fatalf("startOfDay(%v) = %v, want %v", tt.in, got, tt.want)
|
||||
}
|
||||
// Also verify location is preserved.
|
||||
if got.Location().String() != tt.want.Location().String() {
|
||||
t.Fatalf("startOfDay() location = %v, want %v", got.Location(), tt.want.Location())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// int64SliceEqual compares two int64 slices for equality.
|
||||
// Both nil and empty slices are treated as equal.
|
||||
func int64SliceEqual(a, b []int64) bool {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
279
backend/internal/payment/provider/alipay.go
Normal file
279
backend/internal/payment/provider/alipay.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/smartwalle/alipay/v3"
|
||||
)
|
||||
|
||||
// Alipay product codes.
|
||||
const (
|
||||
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
|
||||
alipayProductCodeWapPay = "QUICK_WAP_WAY"
|
||||
)
|
||||
|
||||
// Alipay response constants.
|
||||
const (
|
||||
alipayFundChangeYes = "Y"
|
||||
alipayErrTradeNotExist = "ACQ.TRADE_NOT_EXIST"
|
||||
alipayRefundSuffix = "-refund"
|
||||
)
|
||||
|
||||
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
|
||||
type Alipay struct {
|
||||
instanceID string
|
||||
config map[string]string // appId, privateKey, publicKey (or alipayPublicKey), notifyUrl, returnUrl
|
||||
|
||||
mu sync.Mutex
|
||||
client *alipay.Client
|
||||
}
|
||||
|
||||
// NewAlipay creates a new Alipay provider instance.
|
||||
func NewAlipay(instanceID string, config map[string]string) (*Alipay, error) {
|
||||
required := []string{"appId", "privateKey"}
|
||||
for _, k := range required {
|
||||
if config[k] == "" {
|
||||
return nil, fmt.Errorf("alipay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
return &Alipay{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Alipay) getClient() (*alipay.Client, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.client != nil {
|
||||
return a.client, nil
|
||||
}
|
||||
client, err := alipay.New(a.config["appId"], a.config["privateKey"], true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay init client: %w", err)
|
||||
}
|
||||
pubKey := a.config["publicKey"]
|
||||
if pubKey == "" {
|
||||
pubKey = a.config["alipayPublicKey"]
|
||||
}
|
||||
if pubKey == "" {
|
||||
return nil, fmt.Errorf("alipay config missing required key: publicKey (or alipayPublicKey)")
|
||||
}
|
||||
if err := client.LoadAliPayPublicKey(pubKey); err != nil {
|
||||
return nil, fmt.Errorf("alipay load public key: %w", err)
|
||||
}
|
||||
a.client = client
|
||||
return a.client, nil
|
||||
}
|
||||
|
||||
func (a *Alipay) Name() string { return "Alipay" }
|
||||
func (a *Alipay) ProviderKey() string { return payment.TypeAlipay }
|
||||
func (a *Alipay) SupportedTypes() []payment.PaymentType {
|
||||
return []payment.PaymentType{payment.TypeAlipayDirect}
|
||||
}
|
||||
|
||||
// CreatePayment creates an Alipay payment page URL.
|
||||
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notifyURL := a.config["notifyUrl"]
|
||||
if req.NotifyURL != "" {
|
||||
notifyURL = req.NotifyURL
|
||||
}
|
||||
returnURL := a.config["returnUrl"]
|
||||
if req.ReturnURL != "" {
|
||||
returnURL = req.ReturnURL
|
||||
}
|
||||
|
||||
if req.IsMobile {
|
||||
return a.createTrade(client, req, notifyURL, returnURL, true)
|
||||
}
|
||||
return a.createTrade(client, req, notifyURL, returnURL, false)
|
||||
}
|
||||
|
||||
func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
|
||||
if isMobile {
|
||||
param := alipay.TradeWapPay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodeWapPay
|
||||
param.NotifyURL = notifyURL
|
||||
param.ReturnURL = returnURL
|
||||
|
||||
payURL, err := client.TradeWapPay(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
param := alipay.TradePagePay{}
|
||||
param.OutTradeNo = req.OrderID
|
||||
param.TotalAmount = req.Amount
|
||||
param.Subject = req.Subject
|
||||
param.ProductCode = alipayProductCodePagePay
|
||||
param.NotifyURL = notifyURL
|
||||
param.ReturnURL = returnURL
|
||||
|
||||
payURL, err := client.TradePagePay(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
QRCode: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder queries the trade status via Alipay.
|
||||
func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := client.TradeQuery(ctx, alipay.TradeQuery{OutTradeNo: tradeNo})
|
||||
if err != nil {
|
||||
if isTradeNotExist(err) {
|
||||
return &payment.QueryOrderResponse{
|
||||
TradeNo: tradeNo,
|
||||
Status: payment.ProviderStatusPending,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("alipay TradeQuery: %w", err)
|
||||
}
|
||||
|
||||
status := payment.ProviderStatusPending
|
||||
switch result.TradeStatus {
|
||||
case alipay.TradeStatusSuccess, alipay.TradeStatusFinished:
|
||||
status = payment.ProviderStatusPaid
|
||||
case alipay.TradeStatusClosed:
|
||||
status = payment.ProviderStatusFailed
|
||||
}
|
||||
|
||||
amount, err := strconv.ParseFloat(result.TotalAmount, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err)
|
||||
}
|
||||
|
||||
return &payment.QueryOrderResponse{
|
||||
TradeNo: result.TradeNo,
|
||||
Status: status,
|
||||
Amount: amount,
|
||||
PaidAt: result.SendPayDate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyNotification decodes and verifies an Alipay async notification.
|
||||
func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(rawBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay parse notification: %w", err)
|
||||
}
|
||||
|
||||
notification, err := client.DecodeNotification(ctx, values)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay verify notification: %w", err)
|
||||
}
|
||||
|
||||
status := payment.ProviderStatusFailed
|
||||
if notification.TradeStatus == alipay.TradeStatusSuccess || notification.TradeStatus == alipay.TradeStatusFinished {
|
||||
status = payment.ProviderStatusSuccess
|
||||
}
|
||||
|
||||
amount, err := strconv.ParseFloat(notification.TotalAmount, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
|
||||
}
|
||||
|
||||
return &payment.PaymentNotification{
|
||||
TradeNo: notification.TradeNo,
|
||||
OrderID: notification.OutTradeNo,
|
||||
Amount: amount,
|
||||
Status: status,
|
||||
RawData: rawBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refund requests a refund through Alipay.
|
||||
func (a *Alipay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := client.TradeRefund(ctx, alipay.TradeRefund{
|
||||
OutTradeNo: req.OrderID,
|
||||
RefundAmount: req.Amount,
|
||||
RefundReason: req.Reason,
|
||||
OutRequestNo: fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradeRefund: %w", err)
|
||||
}
|
||||
|
||||
refundStatus := payment.ProviderStatusPending
|
||||
if result.FundChange == alipayFundChangeYes {
|
||||
refundStatus = payment.ProviderStatusSuccess
|
||||
}
|
||||
|
||||
refundID := result.TradeNo
|
||||
if refundID == "" {
|
||||
refundID = req.OrderID + alipayRefundSuffix
|
||||
}
|
||||
|
||||
return &payment.RefundResponse{
|
||||
RefundID: refundID,
|
||||
Status: refundStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelPayment closes a pending trade on Alipay.
|
||||
func (a *Alipay) CancelPayment(ctx context.Context, tradeNo string) error {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.TradeClose(ctx, alipay.TradeClose{OutTradeNo: tradeNo})
|
||||
if err != nil {
|
||||
if isTradeNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("alipay TradeClose: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTradeNotExist(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), alipayErrTradeNotExist)
|
||||
}
|
||||
|
||||
// Ensure interface compliance.
|
||||
var (
|
||||
_ payment.Provider = (*Alipay)(nil)
|
||||
_ payment.CancelableProvider = (*Alipay)(nil)
|
||||
)
|
||||
132
backend/internal/payment/provider/alipay_test.go
Normal file
132
backend/internal/payment/provider/alipay_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
//go:build unit
|
||||
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTradeNotExist(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil error returns false",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "error containing ACQ.TRADE_NOT_EXIST returns true",
|
||||
err: errors.New("alipay: sub_code=ACQ.TRADE_NOT_EXIST, sub_msg=交易不存在"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "error not containing the code returns false",
|
||||
err: errors.New("alipay: sub_code=ACQ.SYSTEM_ERROR, sub_msg=系统错误"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "error with only partial match returns false",
|
||||
err: errors.New("ACQ.TRADE_NOT"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "error with exact constant value returns true",
|
||||
err: errors.New(alipayErrTradeNotExist),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := isTradeNotExist(tt.err)
|
||||
if got != tt.want {
|
||||
t.Errorf("isTradeNotExist(%v) = %v, want %v", tt.err, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAlipay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validConfig := map[string]string{
|
||||
"appId": "2021001234567890",
|
||||
"privateKey": "MIIEvQIBADANBgkqhkiG9w0BAQEFAASC...",
|
||||
}
|
||||
|
||||
// helper to clone and override config fields
|
||||
withOverride := func(overrides map[string]string) map[string]string {
|
||||
cfg := make(map[string]string, len(validConfig))
|
||||
for k, v := range validConfig {
|
||||
cfg[k] = v
|
||||
}
|
||||
for k, v := range overrides {
|
||||
cfg[k] = v
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]string
|
||||
wantErr bool
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid config succeeds",
|
||||
config: validConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing appId",
|
||||
config: withOverride(map[string]string{"appId": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "appId",
|
||||
},
|
||||
{
|
||||
name: "missing privateKey",
|
||||
config: withOverride(map[string]string{"privateKey": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "privateKey",
|
||||
},
|
||||
{
|
||||
name: "nil config map returns error for appId",
|
||||
config: map[string]string{},
|
||||
wantErr: true,
|
||||
errSubstr: "appId",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := NewAlipay("test-instance", tt.config)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil Alipay instance")
|
||||
}
|
||||
if got.instanceID != "test-instance" {
|
||||
t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
278
backend/internal/payment/provider/easypay.go
Normal file
278
backend/internal/payment/provider/easypay.go
Normal file
@@ -0,0 +1,278 @@
|
||||
// Package provider contains concrete payment provider implementations.
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
// EasyPay constants.
|
||||
const (
|
||||
easypayCodeSuccess = 1
|
||||
easypayStatusPaid = 1
|
||||
easypayHTTPTimeout = 10 * time.Second
|
||||
maxEasypayResponseSize = 1 << 20 // 1MB
|
||||
tradeStatusSuccess = "TRADE_SUCCESS"
|
||||
signTypeMD5 = "MD5"
|
||||
)
|
||||
|
||||
// EasyPay implements payment.Provider for the EasyPay aggregation platform.
|
||||
type EasyPay struct {
|
||||
instanceID string
|
||||
config map[string]string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewEasyPay creates a new EasyPay provider.
|
||||
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
|
||||
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
|
||||
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
|
||||
if config[k] == "" {
|
||||
return nil, fmt.Errorf("easypay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
return &EasyPay{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *EasyPay) Name() string { return "EasyPay" }
|
||||
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
|
||||
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
|
||||
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
|
||||
}
|
||||
|
||||
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
// Payment mode determined by instance config, not payment type.
|
||||
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
|
||||
mode := e.config["paymentMode"]
|
||||
if mode == "popup" {
|
||||
return e.createRedirectPayment(req)
|
||||
}
|
||||
return e.createAPIPayment(ctx, req)
|
||||
}
|
||||
|
||||
// createRedirectPayment builds a submit.php URL for browser redirect.
|
||||
// No server-side API call — the user is redirected to EasyPay's hosted page.
|
||||
// TradeNo is empty; it arrives via the notify callback after payment.
|
||||
func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
notifyURL, returnURL := e.resolveURLs(req)
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "type": req.PaymentType,
|
||||
"out_trade_no": req.OrderID, "notify_url": notifyURL,
|
||||
"return_url": returnURL, "name": req.Subject,
|
||||
"money": req.Amount,
|
||||
}
|
||||
if cid := e.resolveCID(req.PaymentType); cid != "" {
|
||||
params["cid"] = cid
|
||||
}
|
||||
params["sign"] = easyPaySign(params, e.config["pkey"])
|
||||
params["sign_type"] = signTypeMD5
|
||||
|
||||
q := url.Values{}
|
||||
for k, v := range params {
|
||||
q.Set(k, v)
|
||||
}
|
||||
base := strings.TrimRight(e.config["apiBase"], "/")
|
||||
payURL := base + "/submit.php?" + q.Encode()
|
||||
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
|
||||
}
|
||||
|
||||
// createAPIPayment calls mapi.php to get payurl/qrcode (existing behavior).
|
||||
func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
notifyURL, returnURL := e.resolveURLs(req)
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "type": req.PaymentType,
|
||||
"out_trade_no": req.OrderID, "notify_url": notifyURL,
|
||||
"return_url": returnURL, "name": req.Subject,
|
||||
"money": req.Amount, "clientip": req.ClientIP,
|
||||
}
|
||||
if cid := e.resolveCID(req.PaymentType); cid != "" {
|
||||
params["cid"] = cid
|
||||
}
|
||||
if req.IsMobile {
|
||||
params["device"] = "mobile"
|
||||
}
|
||||
params["sign"] = easyPaySign(params, e.config["pkey"])
|
||||
params["sign_type"] = signTypeMD5
|
||||
|
||||
body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay create: %w", err)
|
||||
}
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
PayURL string `json:"payurl"`
|
||||
QRCode string `json:"qrcode"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse: %w", err)
|
||||
}
|
||||
if resp.Code != easypayCodeSuccess {
|
||||
return nil, fmt.Errorf("easypay error: %s", resp.Msg)
|
||||
}
|
||||
return &payment.CreatePaymentResponse{TradeNo: resp.TradeNo, PayURL: resp.PayURL, QRCode: resp.QRCode}, nil
|
||||
}
|
||||
|
||||
// resolveURLs returns (notifyURL, returnURL) preferring request values,
|
||||
// falling back to instance config.
|
||||
func (e *EasyPay) resolveURLs(req payment.CreatePaymentRequest) (string, string) {
|
||||
notifyURL := req.NotifyURL
|
||||
if notifyURL == "" {
|
||||
notifyURL = e.config["notifyUrl"]
|
||||
}
|
||||
returnURL := req.ReturnURL
|
||||
if returnURL == "" {
|
||||
returnURL = e.config["returnUrl"]
|
||||
}
|
||||
return notifyURL, returnURL
|
||||
}
|
||||
|
||||
func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
|
||||
params := map[string]string{
|
||||
"act": "order", "pid": e.config["pid"],
|
||||
"key": e.config["pkey"], "out_trade_no": tradeNo,
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay query: %w", err)
|
||||
}
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Status int `json:"status"`
|
||||
Money string `json:"money"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse query: %w", err)
|
||||
}
|
||||
status := payment.ProviderStatusPending
|
||||
if resp.Status == easypayStatusPaid {
|
||||
status = payment.ProviderStatusPaid
|
||||
}
|
||||
amount, _ := strconv.ParseFloat(resp.Money, 64)
|
||||
return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
|
||||
}
|
||||
|
||||
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
|
||||
values, err := url.ParseQuery(rawBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse notify: %w", err)
|
||||
}
|
||||
// url.ParseQuery already decodes values — no additional decode needed.
|
||||
params := make(map[string]string)
|
||||
for k := range values {
|
||||
params[k] = values.Get(k)
|
||||
}
|
||||
sign := params["sign"]
|
||||
if sign == "" {
|
||||
return nil, fmt.Errorf("missing sign")
|
||||
}
|
||||
if !easyPayVerifySign(params, e.config["pkey"], sign) {
|
||||
return nil, fmt.Errorf("invalid signature")
|
||||
}
|
||||
status := payment.ProviderStatusFailed
|
||||
if params["trade_status"] == tradeStatusSuccess {
|
||||
status = payment.ProviderStatusSuccess
|
||||
}
|
||||
amount, _ := strconv.ParseFloat(params["money"], 64)
|
||||
return &payment.PaymentNotification{
|
||||
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
|
||||
Amount: amount, Status: status, RawData: rawBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
params := map[string]string{
|
||||
"pid": e.config["pid"], "key": e.config["pkey"],
|
||||
"trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
|
||||
}
|
||||
body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("easypay refund: %w", err)
|
||||
}
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("easypay parse refund: %w", err)
|
||||
}
|
||||
if resp.Code != easypayCodeSuccess {
|
||||
return nil, fmt.Errorf("easypay refund failed: %s", resp.Msg)
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: req.TradeNo, Status: payment.ProviderStatusSuccess}, nil
|
||||
}
|
||||
|
||||
func (e *EasyPay) resolveCID(paymentType string) string {
|
||||
if strings.HasPrefix(paymentType, "alipay") {
|
||||
if v := e.config["cidAlipay"]; v != "" {
|
||||
return v
|
||||
}
|
||||
return e.config["cid"]
|
||||
}
|
||||
if v := e.config["cidWxpay"]; v != "" {
|
||||
return v
|
||||
}
|
||||
return e.config["cid"]
|
||||
}
|
||||
|
||||
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := e.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
|
||||
}
|
||||
|
||||
func easyPaySign(params map[string]string, pkey string) string {
|
||||
keys := make([]string, 0, len(params))
|
||||
for k, v := range params {
|
||||
if k == "sign" || k == "sign_type" || v == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var buf strings.Builder
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
_ = buf.WriteByte('&')
|
||||
}
|
||||
_, _ = buf.WriteString(k + "=" + params[k])
|
||||
}
|
||||
_, _ = buf.WriteString(pkey)
|
||||
hash := md5.Sum([]byte(buf.String()))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func easyPayVerifySign(params map[string]string, pkey string, sign string) bool {
|
||||
return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign))
|
||||
}
|
||||
180
backend/internal/payment/provider/easypay_sign_test.go
Normal file
180
backend/internal/payment/provider/easypay_sign_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEasyPaySignConsistentOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
"out_trade_no": "ORDER123",
|
||||
"name": "Test Product",
|
||||
"money": "10.00",
|
||||
}
|
||||
pkey := "test_secret_key"
|
||||
|
||||
sign1 := easyPaySign(params, pkey)
|
||||
sign2 := easyPaySign(params, pkey)
|
||||
if sign1 != sign2 {
|
||||
t.Fatalf("easyPaySign should be deterministic: %q != %q", sign1, sign2)
|
||||
}
|
||||
if len(sign1) != 32 {
|
||||
t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPaySignExcludesSignAndSignType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pkey := "my_key"
|
||||
base := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
}
|
||||
withSign := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
"sign": "should_be_ignored",
|
||||
"sign_type": "MD5",
|
||||
}
|
||||
|
||||
signBase := easyPaySign(base, pkey)
|
||||
signWithExtra := easyPaySign(withSign, pkey)
|
||||
|
||||
if signBase != signWithExtra {
|
||||
t.Fatalf("sign and sign_type should be excluded: base=%q, withExtra=%q", signBase, signWithExtra)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPaySignExcludesEmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pkey := "key123"
|
||||
base := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
}
|
||||
withEmpty := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
"device": "",
|
||||
"clientip": "",
|
||||
}
|
||||
|
||||
signBase := easyPaySign(base, pkey)
|
||||
signWithEmpty := easyPaySign(withEmpty, pkey)
|
||||
|
||||
if signBase != signWithEmpty {
|
||||
t.Fatalf("empty values should be excluded: base=%q, withEmpty=%q", signBase, signWithEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayVerifySignValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
"out_trade_no": "ORDER456",
|
||||
"money": "25.00",
|
||||
}
|
||||
pkey := "secret"
|
||||
|
||||
sign := easyPaySign(params, pkey)
|
||||
|
||||
// Add sign to params (as would come in a real callback)
|
||||
params["sign"] = sign
|
||||
params["sign_type"] = "MD5"
|
||||
|
||||
if !easyPayVerifySign(params, pkey, sign) {
|
||||
t.Fatal("easyPayVerifySign should return true for a valid signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayVerifySignTampered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
"out_trade_no": "ORDER789",
|
||||
"money": "50.00",
|
||||
}
|
||||
pkey := "secret"
|
||||
|
||||
sign := easyPaySign(params, pkey)
|
||||
|
||||
// Tamper with the amount
|
||||
params["money"] = "99.99"
|
||||
|
||||
if easyPayVerifySign(params, pkey, sign) {
|
||||
t.Fatal("easyPayVerifySign should return false for tampered params")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayVerifySignWrongKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "wxpay",
|
||||
}
|
||||
|
||||
sign := easyPaySign(params, "correct_key")
|
||||
|
||||
if easyPayVerifySign(params, "wrong_key", sign) {
|
||||
t.Fatal("easyPayVerifySign should return false with wrong key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPaySignEmptyParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sign := easyPaySign(map[string]string{}, "key123")
|
||||
if sign == "" {
|
||||
t.Fatal("easyPaySign with empty params should still produce a hash")
|
||||
}
|
||||
if len(sign) != 32 {
|
||||
t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPaySignSortOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pkey := "test_key"
|
||||
params1 := map[string]string{
|
||||
"a": "1",
|
||||
"b": "2",
|
||||
"c": "3",
|
||||
}
|
||||
params2 := map[string]string{
|
||||
"c": "3",
|
||||
"a": "1",
|
||||
"b": "2",
|
||||
}
|
||||
|
||||
sign1 := easyPaySign(params1, pkey)
|
||||
sign2 := easyPaySign(params2, pkey)
|
||||
|
||||
if sign1 != sign2 {
|
||||
t.Fatalf("easyPaySign should be order-independent: %q != %q", sign1, sign2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := map[string]string{
|
||||
"pid": "1001",
|
||||
"type": "alipay",
|
||||
}
|
||||
pkey := "key"
|
||||
|
||||
if easyPayVerifySign(params, pkey, "00000000000000000000000000000000") {
|
||||
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
|
||||
}
|
||||
}
|
||||
23
backend/internal/payment/provider/factory.go
Normal file
23
backend/internal/payment/provider/factory.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
// CreateProvider creates a Provider from a provider key, instance ID and decrypted config.
|
||||
func CreateProvider(providerKey string, instanceID string, config map[string]string) (payment.Provider, error) {
|
||||
switch providerKey {
|
||||
case "easypay":
|
||||
return NewEasyPay(instanceID, config)
|
||||
case "alipay":
|
||||
return NewAlipay(instanceID, config)
|
||||
case "wxpay":
|
||||
return NewWxpay(instanceID, config)
|
||||
case "stripe":
|
||||
return NewStripe(instanceID, config)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider key: %s", providerKey)
|
||||
}
|
||||
}
|
||||
262
backend/internal/payment/provider/stripe.go
Normal file
262
backend/internal/payment/provider/stripe.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
stripe "github.com/stripe/stripe-go/v85"
|
||||
"github.com/stripe/stripe-go/v85/webhook"
|
||||
)
|
||||
|
||||
// Stripe constants.
|
||||
const (
|
||||
stripeCurrency = "cny"
|
||||
stripeEventPaymentSuccess = "payment_intent.succeeded"
|
||||
stripeEventPaymentFailed = "payment_intent.payment_failed"
|
||||
)
|
||||
|
||||
// Stripe implements the payment.CancelableProvider interface for Stripe payments.
|
||||
type Stripe struct {
|
||||
instanceID string
|
||||
config map[string]string
|
||||
|
||||
mu sync.Mutex
|
||||
initialized bool
|
||||
sc *stripe.Client
|
||||
}
|
||||
|
||||
// NewStripe creates a new Stripe provider instance.
|
||||
func NewStripe(instanceID string, config map[string]string) (*Stripe, error) {
|
||||
if config["secretKey"] == "" {
|
||||
return nil, fmt.Errorf("stripe config missing required key: secretKey")
|
||||
}
|
||||
return &Stripe{
|
||||
instanceID: instanceID,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Stripe) ensureInit() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.initialized {
|
||||
s.sc = stripe.NewClient(s.config["secretKey"])
|
||||
s.initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublishableKey returns the publishable key for frontend use.
|
||||
func (s *Stripe) GetPublishableKey() string {
|
||||
return s.config["publishableKey"]
|
||||
}
|
||||
|
||||
func (s *Stripe) Name() string { return "Stripe" }
|
||||
func (s *Stripe) ProviderKey() string { return payment.TypeStripe }
|
||||
func (s *Stripe) SupportedTypes() []payment.PaymentType {
|
||||
return []payment.PaymentType{payment.TypeStripe}
|
||||
}
|
||||
|
||||
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
|
||||
var stripePaymentMethodTypes = map[string][]string{
|
||||
payment.TypeCard: {"card"},
|
||||
payment.TypeAlipay: {"alipay"},
|
||||
payment.TypeWxpay: {"wechat_pay"},
|
||||
payment.TypeLink: {"link"},
|
||||
}
|
||||
|
||||
// CreatePayment creates a Stripe PaymentIntent.
|
||||
func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
s.ensureInit()
|
||||
|
||||
amountInCents, err := payment.YuanToFen(req.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe create payment: %w", err)
|
||||
}
|
||||
|
||||
// Collect all Stripe payment_method_types from the instance's configured sub-methods
|
||||
methods := resolveStripeMethodTypes(req.InstanceSubMethods)
|
||||
|
||||
pmTypes := make([]*string, len(methods))
|
||||
for i, m := range methods {
|
||||
pmTypes[i] = stripe.String(m)
|
||||
}
|
||||
|
||||
params := &stripe.PaymentIntentCreateParams{
|
||||
Amount: stripe.Int64(amountInCents),
|
||||
Currency: stripe.String(stripeCurrency),
|
||||
PaymentMethodTypes: pmTypes,
|
||||
Description: stripe.String(req.Subject),
|
||||
Metadata: map[string]string{"orderId": req.OrderID},
|
||||
}
|
||||
|
||||
// WeChat Pay requires payment_method_options with client type
|
||||
if hasStripeMethod(methods, "wechat_pay") {
|
||||
params.PaymentMethodOptions = &stripe.PaymentIntentCreatePaymentMethodOptionsParams{
|
||||
WeChatPay: &stripe.PaymentIntentCreatePaymentMethodOptionsWeChatPayParams{
|
||||
Client: stripe.String("web"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
params.SetIdempotencyKey(fmt.Sprintf("pi-%s", req.OrderID))
|
||||
params.Context = ctx
|
||||
|
||||
pi, err := s.sc.V1PaymentIntents.Create(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe create payment: %w", err)
|
||||
}
|
||||
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: pi.ID,
|
||||
ClientSecret: pi.ClientSecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder retrieves a PaymentIntent by ID.
|
||||
func (s *Stripe) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
|
||||
s.ensureInit()
|
||||
|
||||
pi, err := s.sc.V1PaymentIntents.Retrieve(ctx, tradeNo, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe query order: %w", err)
|
||||
}
|
||||
|
||||
status := payment.ProviderStatusPending
|
||||
switch pi.Status {
|
||||
case stripe.PaymentIntentStatusSucceeded:
|
||||
status = payment.ProviderStatusPaid
|
||||
case stripe.PaymentIntentStatusCanceled:
|
||||
status = payment.ProviderStatusFailed
|
||||
}
|
||||
|
||||
return &payment.QueryOrderResponse{
|
||||
TradeNo: pi.ID,
|
||||
Status: status,
|
||||
Amount: payment.FenToYuan(pi.Amount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyNotification verifies a Stripe webhook event.
|
||||
func (s *Stripe) VerifyNotification(_ context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
|
||||
s.ensureInit()
|
||||
|
||||
webhookSecret := s.config["webhookSecret"]
|
||||
if webhookSecret == "" {
|
||||
return nil, fmt.Errorf("stripe webhookSecret not configured")
|
||||
}
|
||||
|
||||
sig := headers["stripe-signature"]
|
||||
if sig == "" {
|
||||
return nil, fmt.Errorf("stripe notification missing stripe-signature header")
|
||||
}
|
||||
|
||||
event, err := webhook.ConstructEvent([]byte(rawBody), sig, webhookSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe verify notification: %w", err)
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripeEventPaymentSuccess:
|
||||
return parseStripePaymentIntent(&event, payment.ProviderStatusSuccess, rawBody)
|
||||
case stripeEventPaymentFailed:
|
||||
return parseStripePaymentIntent(&event, payment.ProviderStatusFailed, rawBody)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string) (*payment.PaymentNotification, error) {
|
||||
var pi stripe.PaymentIntent
|
||||
if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
|
||||
return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
|
||||
}
|
||||
return &payment.PaymentNotification{
|
||||
TradeNo: pi.ID,
|
||||
OrderID: pi.Metadata["orderId"],
|
||||
Amount: payment.FenToYuan(pi.Amount),
|
||||
Status: status,
|
||||
RawData: rawBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refund creates a Stripe refund.
|
||||
func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
s.ensureInit()
|
||||
|
||||
amountInCents, err := payment.YuanToFen(req.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe refund: %w", err)
|
||||
}
|
||||
|
||||
params := &stripe.RefundCreateParams{
|
||||
PaymentIntent: stripe.String(req.TradeNo),
|
||||
Amount: stripe.Int64(amountInCents),
|
||||
Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
|
||||
}
|
||||
params.Context = ctx
|
||||
|
||||
r, err := s.sc.V1Refunds.Create(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stripe refund: %w", err)
|
||||
}
|
||||
|
||||
refundStatus := payment.ProviderStatusPending
|
||||
if r.Status == stripe.RefundStatusSucceeded {
|
||||
refundStatus = payment.ProviderStatusSuccess
|
||||
}
|
||||
|
||||
return &payment.RefundResponse{
|
||||
RefundID: r.ID,
|
||||
Status: refundStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
|
||||
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
|
||||
func resolveStripeMethodTypes(instanceSubMethods string) []string {
|
||||
if instanceSubMethods == "" {
|
||||
return []string{"card"}
|
||||
}
|
||||
var methods []string
|
||||
for _, t := range strings.Split(instanceSubMethods, ",") {
|
||||
t = strings.TrimSpace(t)
|
||||
if mapped, ok := stripePaymentMethodTypes[t]; ok {
|
||||
methods = append(methods, mapped...)
|
||||
}
|
||||
}
|
||||
if len(methods) == 0 {
|
||||
return []string{"card"}
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
// hasStripeMethod checks if the given Stripe method list contains the target method.
|
||||
func hasStripeMethod(methods []string, target string) bool {
|
||||
for _, m := range methods {
|
||||
if m == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CancelPayment cancels a pending PaymentIntent.
|
||||
func (s *Stripe) CancelPayment(ctx context.Context, tradeNo string) error {
|
||||
s.ensureInit()
|
||||
|
||||
_, err := s.sc.V1PaymentIntents.Cancel(ctx, tradeNo, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stripe cancel payment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure interface compliance.
|
||||
var (
|
||||
_ payment.Provider = (*Stripe)(nil)
|
||||
_ payment.CancelableProvider = (*Stripe)(nil)
|
||||
)
|
||||
350
backend/internal/payment/provider/wxpay.go
Normal file
350
backend/internal/payment/provider/wxpay.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/utils"
|
||||
)
|
||||
|
||||
// WeChat Pay constants.
|
||||
const (
|
||||
wxpayCurrency = "CNY"
|
||||
wxpayH5Type = "Wap"
|
||||
)
|
||||
|
||||
// WeChat Pay trade states.
|
||||
const (
|
||||
wxpayTradeStateSuccess = "SUCCESS"
|
||||
wxpayTradeStateRefund = "REFUND"
|
||||
wxpayTradeStateClosed = "CLOSED"
|
||||
wxpayTradeStatePayError = "PAYERROR"
|
||||
)
|
||||
|
||||
// WeChat Pay notification event types.
|
||||
const (
|
||||
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
|
||||
)
|
||||
|
||||
// WeChat Pay error codes.
|
||||
const (
|
||||
wxpayErrNoAuth = "NO_AUTH"
|
||||
)
|
||||
|
||||
type Wxpay struct {
|
||||
instanceID string
|
||||
config map[string]string
|
||||
mu sync.Mutex
|
||||
coreClient *core.Client
|
||||
notifyHandler *notify.Handler
|
||||
}
|
||||
|
||||
func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
|
||||
required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"}
|
||||
for _, k := range required {
|
||||
if config[k] == "" {
|
||||
return nil, fmt.Errorf("wxpay config missing required key: %s", k)
|
||||
}
|
||||
}
|
||||
if len(config["apiV3Key"]) != 32 {
|
||||
return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"]))
|
||||
}
|
||||
return &Wxpay{instanceID: instanceID, config: config}, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) Name() string { return "Wxpay" }
|
||||
func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay }
|
||||
func (w *Wxpay) SupportedTypes() []payment.PaymentType {
|
||||
return []payment.PaymentType{payment.TypeWxpayDirect}
|
||||
}
|
||||
|
||||
func formatPEM(key, keyType string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if strings.HasPrefix(key, "-----BEGIN") {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("-----BEGIN %s-----\n%s\n-----END %s-----", keyType, key, keyType)
|
||||
}
|
||||
|
||||
func (w *Wxpay) ensureClient() (*core.Client, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.coreClient != nil {
|
||||
return w.coreClient, nil
|
||||
}
|
||||
privateKey, publicKey, err := w.loadKeyPair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certSerial := w.config["certSerial"]
|
||||
verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
|
||||
client, err := core.NewClient(context.Background(),
|
||||
option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey),
|
||||
option.WithVerifier(verifier))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay init client: %w", err)
|
||||
}
|
||||
handler, err := notify.NewRSANotifyHandler(w.config["apiV3Key"], verifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay init notify handler: %w", err)
|
||||
}
|
||||
w.notifyHandler = handler
|
||||
w.coreClient = client
|
||||
return w.coreClient, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) {
|
||||
privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("wxpay load private key: %w", err)
|
||||
}
|
||||
publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("wxpay load public key: %w", err)
|
||||
}
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := w.ensureClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Request-first, config-fallback (consistent with EasyPay/Alipay)
|
||||
notifyURL := req.NotifyURL
|
||||
if notifyURL == "" {
|
||||
notifyURL = w.config["notifyUrl"]
|
||||
}
|
||||
if notifyURL == "" {
|
||||
return nil, fmt.Errorf("wxpay notifyUrl is required")
|
||||
}
|
||||
totalFen, err := payment.YuanToFen(req.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay create payment: %w", err)
|
||||
}
|
||||
if req.IsMobile && req.ClientIP != "" {
|
||||
resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if !strings.Contains(err.Error(), wxpayErrNoAuth) {
|
||||
return nil, err
|
||||
}
|
||||
slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID)
|
||||
}
|
||||
return w.createOrder(ctx, client, req, notifyURL, totalFen, false)
|
||||
}
|
||||
|
||||
func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) {
|
||||
if useH5 {
|
||||
return w.prepayH5(ctx, c, req, notifyURL, totalFen)
|
||||
}
|
||||
return w.prepayNative(ctx, c, req, notifyURL, totalFen)
|
||||
}
|
||||
|
||||
func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
|
||||
svc := native.NativeApiService{Client: c}
|
||||
cur := wxpayCurrency
|
||||
resp, _, err := svc.Prepay(ctx, native.PrepayRequest{
|
||||
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
|
||||
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
|
||||
NotifyUrl: core.String(notifyURL),
|
||||
Amount: &native.Amount{Total: core.Int64(totalFen), Currency: &cur},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay native prepay: %w", err)
|
||||
}
|
||||
codeURL := ""
|
||||
if resp.CodeUrl != nil {
|
||||
codeURL = *resp.CodeUrl
|
||||
}
|
||||
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, QRCode: codeURL}, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
|
||||
svc := h5.H5ApiService{Client: c}
|
||||
cur := wxpayCurrency
|
||||
tp := wxpayH5Type
|
||||
resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{
|
||||
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
|
||||
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
|
||||
NotifyUrl: core.String(notifyURL),
|
||||
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
|
||||
SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
|
||||
}
|
||||
h5URL := ""
|
||||
if resp.H5Url != nil {
|
||||
h5URL = *resp.H5Url
|
||||
}
|
||||
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
|
||||
}
|
||||
|
||||
func wxSV(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
func mapWxState(s string) string {
|
||||
switch s {
|
||||
case wxpayTradeStateSuccess:
|
||||
return payment.ProviderStatusPaid
|
||||
case wxpayTradeStateRefund:
|
||||
return payment.ProviderStatusRefunded
|
||||
case wxpayTradeStateClosed, wxpayTradeStatePayError:
|
||||
return payment.ProviderStatusFailed
|
||||
default:
|
||||
return payment.ProviderStatusPending
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
|
||||
c, err := w.ensureClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svc := native.NativeApiService{Client: c}
|
||||
tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
|
||||
OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay query order: %w", err)
|
||||
}
|
||||
var amt float64
|
||||
if tx.Amount != nil && tx.Amount.Total != nil {
|
||||
amt = payment.FenToYuan(*tx.Amount.Total)
|
||||
}
|
||||
id := tradeNo
|
||||
if tx.TransactionId != nil {
|
||||
id = *tx.TransactionId
|
||||
}
|
||||
pa := ""
|
||||
if tx.SuccessTime != nil {
|
||||
pa = *tx.SuccessTime
|
||||
}
|
||||
return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
|
||||
if _, err := w.ensureClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, "/", io.NopCloser(bytes.NewBufferString(rawBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay construct request: %w", err)
|
||||
}
|
||||
for k, v := range headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
var tx payments.Transaction
|
||||
nr, err := w.notifyHandler.ParseNotifyRequest(ctx, r, &tx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay verify notification: %w", err)
|
||||
}
|
||||
if nr.EventType != wxpayEventTransactionSuccess {
|
||||
return nil, nil
|
||||
}
|
||||
var amt float64
|
||||
if tx.Amount != nil && tx.Amount.Total != nil {
|
||||
amt = payment.FenToYuan(*tx.Amount.Total)
|
||||
}
|
||||
st := payment.ProviderStatusFailed
|
||||
if wxSV(tx.TradeState) == wxpayTradeStateSuccess {
|
||||
st = payment.ProviderStatusSuccess
|
||||
}
|
||||
return &payment.PaymentNotification{
|
||||
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
|
||||
Amount: amt, Status: st, RawData: rawBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
|
||||
c, err := w.ensureClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rf, err := payment.YuanToFen(req.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay refund amount: %w", err)
|
||||
}
|
||||
tf, err := w.queryOrderTotalFen(ctx, c, req.OrderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs := refunddomestic.RefundsApiService{Client: c}
|
||||
cur := wxpayCurrency
|
||||
res, _, err := rs.Create(ctx, refunddomestic.CreateRequest{
|
||||
OutTradeNo: core.String(req.OrderID),
|
||||
OutRefundNo: core.String(fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano())),
|
||||
Reason: core.String(req.Reason),
|
||||
Amount: &refunddomestic.AmountReq{Refund: core.Int64(rf), Total: core.Int64(tf), Currency: &cur},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wxpay refund: %w", err)
|
||||
}
|
||||
rid := wxSV(res.RefundId)
|
||||
if rid == "" {
|
||||
rid = fmt.Sprintf("%s-refund", req.OrderID)
|
||||
}
|
||||
st := payment.ProviderStatusPending
|
||||
if res.Status != nil && *res.Status == refunddomestic.STATUS_SUCCESS {
|
||||
st = payment.ProviderStatusSuccess
|
||||
}
|
||||
return &payment.RefundResponse{RefundID: rid, Status: st}, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) queryOrderTotalFen(ctx context.Context, c *core.Client, orderID string) (int64, error) {
|
||||
svc := native.NativeApiService{Client: c}
|
||||
tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
|
||||
OutTradeNo: core.String(orderID), Mchid: core.String(w.config["mchId"]),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("wxpay refund query order: %w", err)
|
||||
}
|
||||
var tf int64
|
||||
if tx.Amount != nil && tx.Amount.Total != nil {
|
||||
tf = *tx.Amount.Total
|
||||
}
|
||||
return tf, nil
|
||||
}
|
||||
|
||||
func (w *Wxpay) CancelPayment(ctx context.Context, tradeNo string) error {
|
||||
c, err := w.ensureClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
svc := native.NativeApiService{Client: c}
|
||||
_, err = svc.CloseOrder(ctx, native.CloseOrderRequest{
|
||||
OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("wxpay cancel payment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ payment.Provider = (*Wxpay)(nil)
|
||||
_ payment.CancelableProvider = (*Wxpay)(nil)
|
||||
)
|
||||
259
backend/internal/payment/provider/wxpay_test.go
Normal file
259
backend/internal/payment/provider/wxpay_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//go:build unit
|
||||
|
||||
package provider
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestMapWxState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "SUCCESS maps to paid",
|
||||
input: wxpayTradeStateSuccess,
|
||||
want: payment.ProviderStatusPaid,
|
||||
},
|
||||
{
|
||||
name: "REFUND maps to refunded",
|
||||
input: wxpayTradeStateRefund,
|
||||
want: payment.ProviderStatusRefunded,
|
||||
},
|
||||
{
|
||||
name: "CLOSED maps to failed",
|
||||
input: wxpayTradeStateClosed,
|
||||
want: payment.ProviderStatusFailed,
|
||||
},
|
||||
{
|
||||
name: "PAYERROR maps to failed",
|
||||
input: wxpayTradeStatePayError,
|
||||
want: payment.ProviderStatusFailed,
|
||||
},
|
||||
{
|
||||
name: "unknown state maps to pending",
|
||||
input: "NOTPAY",
|
||||
want: payment.ProviderStatusPending,
|
||||
},
|
||||
{
|
||||
name: "empty string maps to pending",
|
||||
input: "",
|
||||
want: payment.ProviderStatusPending,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := mapWxState(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapWxState(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWxSV(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil pointer returns empty string",
|
||||
input: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "non-nil pointer returns value",
|
||||
input: strPtr("hello"),
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "pointer to empty string returns empty string",
|
||||
input: strPtr(""),
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := wxSV(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("wxSV() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestFormatPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
keyType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "raw key gets wrapped with headers",
|
||||
key: "MIIBIjANBgkqhki...",
|
||||
keyType: "PUBLIC KEY",
|
||||
want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
|
||||
},
|
||||
{
|
||||
name: "already formatted key is returned as-is",
|
||||
key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
|
||||
keyType: "PRIVATE KEY",
|
||||
want: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
|
||||
},
|
||||
{
|
||||
name: "key with leading/trailing whitespace is trimmed before check",
|
||||
key: " \n MIIBIjANBgkqhki... \n ",
|
||||
keyType: "PUBLIC KEY",
|
||||
want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
|
||||
},
|
||||
{
|
||||
name: "already formatted key with whitespace is trimmed and returned",
|
||||
key: " -----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY----- ",
|
||||
keyType: "RSA PRIVATE KEY",
|
||||
want: "-----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY-----",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatPEM(tt.key, tt.keyType)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatPEM(%q, %q) =\n%s\nwant:\n%s", tt.key, tt.keyType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWxpay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validConfig := map[string]string{
|
||||
"appId": "wx1234567890",
|
||||
"mchId": "1234567890",
|
||||
"privateKey": "fake-private-key",
|
||||
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
|
||||
"publicKey": "fake-public-key",
|
||||
"publicKeyId": "key-id-001",
|
||||
"certSerial": "SERIAL001",
|
||||
}
|
||||
|
||||
// helper to clone and override config fields
|
||||
withOverride := func(overrides map[string]string) map[string]string {
|
||||
cfg := make(map[string]string, len(validConfig))
|
||||
for k, v := range validConfig {
|
||||
cfg[k] = v
|
||||
}
|
||||
for k, v := range overrides {
|
||||
cfg[k] = v
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]string
|
||||
wantErr bool
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid config succeeds",
|
||||
config: validConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing appId",
|
||||
config: withOverride(map[string]string{"appId": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "appId",
|
||||
},
|
||||
{
|
||||
name: "missing mchId",
|
||||
config: withOverride(map[string]string{"mchId": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "mchId",
|
||||
},
|
||||
{
|
||||
name: "missing privateKey",
|
||||
config: withOverride(map[string]string{"privateKey": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "privateKey",
|
||||
},
|
||||
{
|
||||
name: "missing apiV3Key",
|
||||
config: withOverride(map[string]string{"apiV3Key": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "apiV3Key",
|
||||
},
|
||||
{
|
||||
name: "missing publicKey",
|
||||
config: withOverride(map[string]string{"publicKey": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "publicKey",
|
||||
},
|
||||
{
|
||||
name: "missing publicKeyId",
|
||||
config: withOverride(map[string]string{"publicKeyId": ""}),
|
||||
wantErr: true,
|
||||
errSubstr: "publicKeyId",
|
||||
},
|
||||
{
|
||||
name: "apiV3Key too short",
|
||||
config: withOverride(map[string]string{"apiV3Key": "short"}),
|
||||
wantErr: true,
|
||||
errSubstr: "exactly 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "apiV3Key too long",
|
||||
config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
|
||||
wantErr: true,
|
||||
errSubstr: "exactly 32 bytes",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := NewWxpay("test-instance", tt.config)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil Wxpay instance")
|
||||
}
|
||||
if got.instanceID != "test-instance" {
|
||||
t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
85
backend/internal/payment/registry.go
Normal file
85
backend/internal/payment/registry.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// Registry is a thread-safe registry mapping PaymentType to Provider.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
providers map[PaymentType]Provider
|
||||
}
|
||||
|
||||
// ErrProviderNotFound is returned when a requested payment provider is not registered.
|
||||
var ErrProviderNotFound = infraerrors.NotFound("PROVIDER_NOT_FOUND", "payment provider not registered")
|
||||
|
||||
// NewRegistry creates a new empty provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
providers: make(map[PaymentType]Provider),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a provider for each of its supported payment types.
|
||||
// If a type was previously registered, it is overwritten.
|
||||
func (r *Registry) Register(p Provider) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, t := range p.SupportedTypes() {
|
||||
r.providers[t] = p
|
||||
}
|
||||
}
|
||||
|
||||
// GetProvider returns the provider registered for the given payment type.
|
||||
func (r *Registry) GetProvider(t PaymentType) (Provider, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
p, ok := r.providers[t]
|
||||
if !ok {
|
||||
return nil, ErrProviderNotFound
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetProviderByKey returns the first provider whose ProviderKey matches the given key.
|
||||
func (r *Registry) GetProviderByKey(key string) (Provider, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
for _, p := range r.providers {
|
||||
if p.ProviderKey() == key {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrProviderNotFound
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider key for the given payment type, or empty string if not found.
|
||||
func (r *Registry) GetProviderKey(t PaymentType) string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
p, ok := r.providers[t]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return p.ProviderKey()
|
||||
}
|
||||
|
||||
// SupportedTypes returns all currently registered payment types.
|
||||
func (r *Registry) SupportedTypes() []PaymentType {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
types := make([]PaymentType, 0, len(r.providers))
|
||||
for t := range r.providers {
|
||||
types = append(types, t)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
// Clear removes all registered providers.
|
||||
func (r *Registry) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.providers = make(map[PaymentType]Provider)
|
||||
}
|
||||
234
backend/internal/payment/registry_test.go
Normal file
234
backend/internal/payment/registry_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockProvider implements the Provider interface for testing.
|
||||
type mockProvider struct {
|
||||
name string
|
||||
key string
|
||||
supportedTypes []PaymentType
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) ProviderKey() string { return m.key }
|
||||
func (m *mockProvider) SupportedTypes() []PaymentType { return m.supportedTypes }
|
||||
func (m *mockProvider) CreatePayment(_ context.Context, _ CreatePaymentRequest) (*CreatePaymentResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockProvider) QueryOrder(_ context.Context, _ string) (*QueryOrderResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockProvider) VerifyNotification(_ context.Context, _ string, _ map[string]string) (*PaymentNotification, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockProvider) Refund(_ context.Context, _ RefundRequest) (*RefundResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestRegistryRegisterAndGetProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
p := &mockProvider{
|
||||
name: "TestPay",
|
||||
key: "testpay",
|
||||
supportedTypes: []PaymentType{TypeAlipay, TypeWxpay},
|
||||
}
|
||||
r.Register(p)
|
||||
|
||||
got, err := r.GetProvider(TypeAlipay)
|
||||
if err != nil {
|
||||
t.Fatalf("GetProvider(alipay) error: %v", err)
|
||||
}
|
||||
if got.ProviderKey() != "testpay" {
|
||||
t.Fatalf("GetProvider(alipay) key = %q, want %q", got.ProviderKey(), "testpay")
|
||||
}
|
||||
|
||||
got2, err := r.GetProvider(TypeWxpay)
|
||||
if err != nil {
|
||||
t.Fatalf("GetProvider(wxpay) error: %v", err)
|
||||
}
|
||||
if got2.ProviderKey() != "testpay" {
|
||||
t.Fatalf("GetProvider(wxpay) key = %q, want %q", got2.ProviderKey(), "testpay")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetProviderNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
_, err := r.GetProvider("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("GetProvider for unregistered type should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetProviderByKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
p := &mockProvider{
|
||||
name: "EasyPay",
|
||||
key: "easypay",
|
||||
supportedTypes: []PaymentType{TypeAlipay},
|
||||
}
|
||||
r.Register(p)
|
||||
|
||||
got, err := r.GetProviderByKey("easypay")
|
||||
if err != nil {
|
||||
t.Fatalf("GetProviderByKey error: %v", err)
|
||||
}
|
||||
if got.Name() != "EasyPay" {
|
||||
t.Fatalf("GetProviderByKey name = %q, want %q", got.Name(), "EasyPay")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetProviderByKeyNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
_, err := r.GetProviderByKey("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("GetProviderByKey for unknown key should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetProviderKeyUnknownType(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
key := r.GetProviderKey("unknown_type")
|
||||
if key != "" {
|
||||
t.Fatalf("GetProviderKey for unknown type should return empty, got %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetProviderKeyKnownType(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
p := &mockProvider{
|
||||
name: "Stripe",
|
||||
key: "stripe",
|
||||
supportedTypes: []PaymentType{TypeStripe},
|
||||
}
|
||||
r.Register(p)
|
||||
|
||||
key := r.GetProviderKey(TypeStripe)
|
||||
if key != "stripe" {
|
||||
t.Fatalf("GetProviderKey(stripe) = %q, want %q", key, "stripe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySupportedTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
p1 := &mockProvider{
|
||||
name: "EasyPay",
|
||||
key: "easypay",
|
||||
supportedTypes: []PaymentType{TypeAlipay, TypeWxpay},
|
||||
}
|
||||
p2 := &mockProvider{
|
||||
name: "Stripe",
|
||||
key: "stripe",
|
||||
supportedTypes: []PaymentType{TypeStripe},
|
||||
}
|
||||
r.Register(p1)
|
||||
r.Register(p2)
|
||||
|
||||
types := r.SupportedTypes()
|
||||
if len(types) != 3 {
|
||||
t.Fatalf("SupportedTypes() len = %d, want 3", len(types))
|
||||
}
|
||||
|
||||
typeSet := make(map[PaymentType]bool)
|
||||
for _, tp := range types {
|
||||
typeSet[tp] = true
|
||||
}
|
||||
for _, expected := range []PaymentType{TypeAlipay, TypeWxpay, TypeStripe} {
|
||||
if !typeSet[expected] {
|
||||
t.Fatalf("SupportedTypes() missing %q", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySupportedTypesEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
types := r.SupportedTypes()
|
||||
if len(types) != 0 {
|
||||
t.Fatalf("SupportedTypes() on empty registry should be empty, got %d", len(types))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryOverwriteExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
p1 := &mockProvider{
|
||||
name: "OldPay",
|
||||
key: "old",
|
||||
supportedTypes: []PaymentType{TypeAlipay},
|
||||
}
|
||||
p2 := &mockProvider{
|
||||
name: "NewPay",
|
||||
key: "new",
|
||||
supportedTypes: []PaymentType{TypeAlipay},
|
||||
}
|
||||
r.Register(p1)
|
||||
r.Register(p2)
|
||||
|
||||
got, err := r.GetProvider(TypeAlipay)
|
||||
if err != nil {
|
||||
t.Fatalf("GetProvider error: %v", err)
|
||||
}
|
||||
if got.Name() != "NewPay" {
|
||||
t.Fatalf("expected overwritten provider, got %q", got.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRegistry()
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 2)
|
||||
|
||||
// Concurrent writers
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
p := &mockProvider{
|
||||
name: fmt.Sprintf("Provider-%d", idx),
|
||||
key: fmt.Sprintf("key-%d", idx),
|
||||
supportedTypes: []PaymentType{PaymentType(fmt.Sprintf("type-%d", idx))},
|
||||
}
|
||||
r.Register(p)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent readers
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = r.SupportedTypes()
|
||||
_, _ = r.GetProvider("some-type")
|
||||
_ = r.GetProviderKey("some-type")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
types := r.SupportedTypes()
|
||||
if len(types) != goroutines {
|
||||
t.Fatalf("after concurrent registration, expected %d types, got %d", goroutines, len(types))
|
||||
}
|
||||
}
|
||||
180
backend/internal/payment/types.go
Normal file
180
backend/internal/payment/types.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// Package payment provides the core payment provider abstraction,
|
||||
// registry, load balancing, and shared utilities for the payment subsystem.
|
||||
package payment
|
||||
|
||||
import "context"
|
||||
|
||||
// PaymentType represents a supported payment method.
|
||||
type PaymentType = string
|
||||
|
||||
// Supported payment type constants.
|
||||
const (
|
||||
TypeAlipay PaymentType = "alipay"
|
||||
TypeWxpay PaymentType = "wxpay"
|
||||
TypeAlipayDirect PaymentType = "alipay_direct"
|
||||
TypeWxpayDirect PaymentType = "wxpay_direct"
|
||||
TypeStripe PaymentType = "stripe"
|
||||
TypeCard PaymentType = "card"
|
||||
TypeLink PaymentType = "link"
|
||||
TypeEasyPay PaymentType = "easypay"
|
||||
)
|
||||
|
||||
// Order status constants shared across payment and service layers.
|
||||
const (
|
||||
OrderStatusPending = "PENDING"
|
||||
OrderStatusPaid = "PAID"
|
||||
OrderStatusRecharging = "RECHARGING"
|
||||
OrderStatusCompleted = "COMPLETED"
|
||||
OrderStatusExpired = "EXPIRED"
|
||||
OrderStatusCancelled = "CANCELLED"
|
||||
OrderStatusFailed = "FAILED"
|
||||
OrderStatusRefundRequested = "REFUND_REQUESTED"
|
||||
OrderStatusRefunding = "REFUNDING"
|
||||
OrderStatusPartiallyRefunded = "PARTIALLY_REFUNDED"
|
||||
OrderStatusRefunded = "REFUNDED"
|
||||
OrderStatusRefundFailed = "REFUND_FAILED"
|
||||
)
|
||||
|
||||
// Order types distinguish balance recharges from subscription purchases.
|
||||
const (
|
||||
OrderTypeBalance = "balance"
|
||||
OrderTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// Entity statuses shared across users, groups, etc.
|
||||
const (
|
||||
EntityStatusActive = "active"
|
||||
)
|
||||
|
||||
// Deduction types for refund flow.
|
||||
const (
|
||||
DeductionTypeBalance = "balance"
|
||||
DeductionTypeSubscription = "subscription"
|
||||
DeductionTypeNone = "none"
|
||||
)
|
||||
|
||||
// Payment notification status values.
|
||||
const (
|
||||
NotificationStatusSuccess = "success"
|
||||
NotificationStatusPaid = "paid"
|
||||
)
|
||||
|
||||
// Provider-level status constants returned by provider implementations
|
||||
// to the service layer (lowercase, distinct from OrderStatus uppercase constants).
|
||||
const (
|
||||
ProviderStatusPending = "pending"
|
||||
ProviderStatusPaid = "paid"
|
||||
ProviderStatusSuccess = "success"
|
||||
ProviderStatusFailed = "failed"
|
||||
ProviderStatusRefunded = "refunded"
|
||||
)
|
||||
|
||||
// DefaultLoadBalanceStrategy is the default load-balancing strategy
|
||||
// used when no strategy is configured.
|
||||
const DefaultLoadBalanceStrategy = "round-robin"
|
||||
|
||||
// ConfigKeyPublishableKey is the config map key for Stripe's publishable key.
|
||||
const ConfigKeyPublishableKey = "publishableKey"
|
||||
|
||||
// GetBasePaymentType extracts the base payment method from a composite key.
|
||||
// For example, "alipay_direct" -> "alipay".
|
||||
func GetBasePaymentType(t string) string {
|
||||
switch {
|
||||
case t == TypeEasyPay:
|
||||
return TypeEasyPay
|
||||
case t == TypeStripe || t == TypeCard || t == TypeLink:
|
||||
return TypeStripe
|
||||
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
|
||||
return TypeAlipay
|
||||
case len(t) >= len(TypeWxpay) && t[:len(TypeWxpay)] == TypeWxpay:
|
||||
return TypeWxpay
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePaymentRequest holds the parameters for creating a new payment.
|
||||
type CreatePaymentRequest struct {
|
||||
OrderID string // Internal order ID
|
||||
Amount string // Pay amount in CNY (formatted to 2 decimal places)
|
||||
PaymentType string // e.g. "alipay", "wxpay", "stripe"
|
||||
Subject string // Product description
|
||||
NotifyURL string // Webhook callback URL
|
||||
ReturnURL string // Browser redirect URL after payment
|
||||
ClientIP string // Payer's IP address
|
||||
IsMobile bool // Whether the request comes from a mobile device
|
||||
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
|
||||
}
|
||||
|
||||
// CreatePaymentResponse is returned after successfully initiating a payment.
|
||||
type CreatePaymentResponse struct {
|
||||
TradeNo string // Third-party transaction ID
|
||||
PayURL string // H5 payment URL (alipay/wxpay)
|
||||
QRCode string // QR code content for scanning
|
||||
ClientSecret string // Stripe PaymentIntent client secret
|
||||
}
|
||||
|
||||
// QueryOrderResponse describes the payment status from the upstream provider.
|
||||
type QueryOrderResponse struct {
|
||||
TradeNo string
|
||||
Status string // "pending", "paid", "failed", "refunded"
|
||||
Amount float64 // Amount in CNY
|
||||
PaidAt string // RFC3339 timestamp or empty
|
||||
}
|
||||
|
||||
// PaymentNotification is the parsed result of a webhook/notify callback.
|
||||
type PaymentNotification struct {
|
||||
TradeNo string
|
||||
OrderID string
|
||||
Amount float64
|
||||
Status string // "success" or "failed"
|
||||
RawData string // Raw notification body for audit
|
||||
}
|
||||
|
||||
// RefundRequest contains the parameters for requesting a refund.
|
||||
type RefundRequest struct {
|
||||
TradeNo string
|
||||
OrderID string
|
||||
Amount string // Refund amount formatted to 2 decimal places
|
||||
Reason string
|
||||
}
|
||||
|
||||
// RefundResponse is returned after a refund request.
|
||||
type RefundResponse struct {
|
||||
RefundID string
|
||||
Status string // "success", "pending", "failed"
|
||||
}
|
||||
|
||||
// InstanceSelection holds the selected provider instance and its decrypted config.
|
||||
type InstanceSelection struct {
|
||||
InstanceID string
|
||||
Config map[string]string
|
||||
SupportedTypes string // Comma-separated list of supported payment types from the instance
|
||||
PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
|
||||
}
|
||||
|
||||
// Provider defines the interface that all payment providers must implement.
|
||||
type Provider interface {
|
||||
// Name returns a human-readable name for this provider.
|
||||
Name() string
|
||||
// ProviderKey returns the unique key identifying this provider type (e.g. "easypay").
|
||||
ProviderKey() string
|
||||
// SupportedTypes returns the list of payment types this provider handles.
|
||||
SupportedTypes() []PaymentType
|
||||
// CreatePayment initiates a payment and returns the upstream response.
|
||||
CreatePayment(ctx context.Context, req CreatePaymentRequest) (*CreatePaymentResponse, error)
|
||||
// QueryOrder queries the payment status of the given trade number.
|
||||
QueryOrder(ctx context.Context, tradeNo string) (*QueryOrderResponse, error)
|
||||
// VerifyNotification parses and verifies a webhook callback.
|
||||
// Returns nil for unrecognized or irrelevant events (caller should return 200).
|
||||
VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*PaymentNotification, error)
|
||||
// Refund requests a refund from the upstream provider.
|
||||
Refund(ctx context.Context, req RefundRequest) (*RefundResponse, error)
|
||||
}
|
||||
|
||||
// CancelableProvider extends Provider with the ability to cancel pending payments.
|
||||
type CancelableProvider interface {
|
||||
Provider
|
||||
// CancelPayment cancels/expires a pending payment on the upstream platform.
|
||||
CancelPayment(ctx context.Context, tradeNo string) error
|
||||
}
|
||||
53
backend/internal/payment/wire.go
Normal file
53
backend/internal/payment/wire.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// EncryptionKey is a named type for the payment encryption key (AES-256, 32 bytes).
|
||||
// Using a named type avoids Wire ambiguity with other []byte parameters.
|
||||
type EncryptionKey []byte
|
||||
|
||||
// ProvideEncryptionKey derives the payment encryption key from the TOTP encryption key in config.
|
||||
// When the key is empty, nil is returned (payment features that need encryption will be disabled).
|
||||
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
|
||||
// to prevent startup with a misconfigured encryption key.
|
||||
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
|
||||
return nil, nil
|
||||
}
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("payment encryption key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
return EncryptionKey(key), nil
|
||||
}
|
||||
|
||||
// ProvideRegistry creates an empty payment provider registry.
|
||||
// Providers are registered at runtime after application startup.
|
||||
func ProvideRegistry() *Registry {
|
||||
return NewRegistry()
|
||||
}
|
||||
|
||||
// ProvideDefaultLoadBalancer creates a DefaultLoadBalancer backed by the ent client.
|
||||
func ProvideDefaultLoadBalancer(client *dbent.Client, key EncryptionKey) *DefaultLoadBalancer {
|
||||
return NewDefaultLoadBalancer(client, []byte(key))
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for the payment package.
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideEncryptionKey,
|
||||
ProvideRegistry,
|
||||
ProvideDefaultLoadBalancer,
|
||||
wire.Bind(new(LoadBalancer), new(*DefaultLoadBalancer)),
|
||||
)
|
||||
@@ -583,6 +583,24 @@ func TestAPIContracts(t *testing.T) {
|
||||
"enable_cch_signing": false,
|
||||
"enable_fingerprint_unification": true,
|
||||
"enable_metadata_passthrough": false,
|
||||
"payment_enabled": false,
|
||||
"payment_min_amount": 0,
|
||||
"payment_max_amount": 0,
|
||||
"payment_daily_limit": 0,
|
||||
"payment_order_timeout_minutes": 0,
|
||||
"payment_max_pending_orders": 0,
|
||||
"payment_enabled_types": null,
|
||||
"payment_balance_disabled": false,
|
||||
"payment_load_balance_strategy": "",
|
||||
"payment_product_name_prefix": "",
|
||||
"payment_product_name_suffix": "",
|
||||
"payment_help_image_url": "",
|
||||
"payment_help_text": "",
|
||||
"payment_cancel_rate_limit_enabled": false,
|
||||
"payment_cancel_rate_limit_max": 0,
|
||||
"payment_cancel_rate_limit_window": 0,
|
||||
"payment_cancel_rate_limit_unit": "",
|
||||
"payment_cancel_rate_limit_window_mode": "",
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": []
|
||||
}
|
||||
@@ -696,7 +714,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
|
||||
@@ -111,4 +111,5 @@ func registerRoutes(
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
||||
}
|
||||
|
||||
103
backend/internal/server/routes/payment.go
Normal file
103
backend/internal/server/routes/payment.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterPaymentRoutes registers all payment-related routes:
|
||||
// user-facing endpoints, webhook endpoints, and admin endpoints.
|
||||
func RegisterPaymentRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
paymentHandler *handler.PaymentHandler,
|
||||
webhookHandler *handler.PaymentWebhookHandler,
|
||||
adminPaymentHandler *admin.PaymentHandler,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
adminAuth middleware.AdminAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// --- User-facing payment endpoints (authenticated) ---
|
||||
authenticated := v1.Group("/payment")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
|
||||
authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
|
||||
authenticated.GET("/plans", paymentHandler.GetPlans)
|
||||
authenticated.GET("/channels", paymentHandler.GetChannels)
|
||||
authenticated.GET("/limits", paymentHandler.GetLimits)
|
||||
|
||||
orders := authenticated.Group("/orders")
|
||||
{
|
||||
orders.POST("", paymentHandler.CreateOrder)
|
||||
orders.POST("/verify", paymentHandler.VerifyOrder)
|
||||
orders.GET("/my", paymentHandler.GetMyOrders)
|
||||
orders.GET("/:id", paymentHandler.GetOrder)
|
||||
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
|
||||
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Public payment endpoints (no auth) ---
|
||||
// Payment result page needs to verify order status without login
|
||||
// (user session may have expired during provider redirect).
|
||||
public := v1.Group("/payment/public")
|
||||
{
|
||||
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
|
||||
}
|
||||
|
||||
// --- Webhook endpoints (no auth) ---
|
||||
webhook := v1.Group("/payment/webhook")
|
||||
{
|
||||
// EasyPay sends GET callbacks with query params
|
||||
webhook.GET("/easypay", webhookHandler.EasyPayNotify)
|
||||
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
|
||||
webhook.POST("/alipay", webhookHandler.AlipayNotify)
|
||||
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
|
||||
webhook.POST("/stripe", webhookHandler.StripeWebhook)
|
||||
}
|
||||
|
||||
// --- Admin payment endpoints (admin auth) ---
|
||||
adminGroup := v1.Group("/admin/payment")
|
||||
adminGroup.Use(gin.HandlerFunc(adminAuth))
|
||||
{
|
||||
// Dashboard
|
||||
adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
|
||||
|
||||
// Config
|
||||
adminGroup.GET("/config", adminPaymentHandler.GetConfig)
|
||||
adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
|
||||
|
||||
// Orders
|
||||
adminOrders := adminGroup.Group("/orders")
|
||||
{
|
||||
adminOrders.GET("", adminPaymentHandler.ListOrders)
|
||||
adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
|
||||
adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
|
||||
adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
|
||||
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
|
||||
}
|
||||
|
||||
// Subscription Plans
|
||||
plans := adminGroup.Group("/plans")
|
||||
{
|
||||
plans.GET("", adminPaymentHandler.ListPlans)
|
||||
plans.POST("", adminPaymentHandler.CreatePlan)
|
||||
plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
|
||||
plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
|
||||
}
|
||||
|
||||
// Provider Instances
|
||||
providers := adminGroup.Group("/providers")
|
||||
{
|
||||
providers.GET("", adminPaymentHandler.ListProviders)
|
||||
providers.POST("", adminPaymentHandler.CreateProvider)
|
||||
providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
172
backend/internal/service/payment_config_limits.go
Normal file
172
backend/internal/service/payment_config_limits.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
// GetAvailableMethodLimits collects all payment types from enabled provider
|
||||
// instances and returns limits for each, plus the global widest range.
|
||||
// Stripe sub-types (card, link) are aggregated under "stripe".
|
||||
func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*MethodLimitsResponse, error) {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query provider instances: %w", err)
|
||||
}
|
||||
typeInstances := pcGroupByPaymentType(instances)
|
||||
resp := &MethodLimitsResponse{
|
||||
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
||||
}
|
||||
for pt, insts := range typeInstances {
|
||||
ml := pcAggregateMethodLimits(pt, insts)
|
||||
resp.Methods[ml.PaymentType] = ml
|
||||
}
|
||||
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
|
||||
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query provider instances: %w", err)
|
||||
}
|
||||
result := make([]MethodLimits, 0, len(types))
|
||||
for _, pt := range types {
|
||||
var matching []*dbent.PaymentProviderInstance
|
||||
for _, inst := range instances {
|
||||
if payment.InstanceSupportsType(inst.SupportedTypes, pt) {
|
||||
matching = append(matching, inst)
|
||||
}
|
||||
}
|
||||
result = append(result, pcAggregateMethodLimits(pt, matching))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// pcGroupByPaymentType groups instances by user-facing payment type.
|
||||
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
|
||||
// because the user sees a single "Stripe" button, not individual sub-methods.
|
||||
// Uses a seen set to avoid counting one instance twice.
|
||||
func pcGroupByPaymentType(instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
|
||||
typeInstances := make(map[string][]*dbent.PaymentProviderInstance)
|
||||
seen := make(map[string]map[int64]bool)
|
||||
add := func(key string, inst *dbent.PaymentProviderInstance) {
|
||||
if seen[key] == nil {
|
||||
seen[key] = make(map[int64]bool)
|
||||
}
|
||||
if !seen[key][int64(inst.ID)] {
|
||||
seen[key][int64(inst.ID)] = true
|
||||
typeInstances[key] = append(typeInstances[key], inst)
|
||||
}
|
||||
}
|
||||
for _, inst := range instances {
|
||||
// Stripe provider: all sub-types → single "stripe" group
|
||||
if inst.ProviderKey == payment.TypeStripe {
|
||||
add(payment.TypeStripe, inst)
|
||||
continue
|
||||
}
|
||||
for _, t := range splitTypes(inst.SupportedTypes) {
|
||||
add(t, inst)
|
||||
}
|
||||
}
|
||||
return typeInstances
|
||||
}
|
||||
|
||||
// pcInstanceTypeLimits extracts per-type limits from a provider instance.
|
||||
// Returns (limits, true) if configured; (zero, false) if unlimited.
|
||||
// For Stripe instances, limits are stored under "stripe" key regardless of sub-types.
|
||||
func pcInstanceTypeLimits(inst *dbent.PaymentProviderInstance, pt string) (payment.ChannelLimits, bool) {
|
||||
if inst.Limits == "" {
|
||||
return payment.ChannelLimits{}, false
|
||||
}
|
||||
var limits payment.InstanceLimits
|
||||
if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
|
||||
return payment.ChannelLimits{}, false
|
||||
}
|
||||
cl, ok := limits[pt]
|
||||
return cl, ok
|
||||
}
|
||||
|
||||
// unionFloat merges a single limit value into the aggregate using UNION semantics.
|
||||
// - For "min" fields (wantMin=true): keeps the lowest non-zero value
|
||||
// - For "max"/"cap" fields (wantMin=false): keeps the highest non-zero value
|
||||
// - If any value is 0 (unlimited), the result is unlimited.
|
||||
//
|
||||
// Returns (aggregated value, still limited).
|
||||
func unionFloat(agg float64, limited bool, val float64, wantMin bool) (float64, bool) {
|
||||
if val == 0 {
|
||||
return agg, false
|
||||
}
|
||||
if !limited {
|
||||
return agg, false
|
||||
}
|
||||
if agg == 0 {
|
||||
return val, true
|
||||
}
|
||||
if wantMin && val < agg {
|
||||
return val, true
|
||||
}
|
||||
if !wantMin && val > agg {
|
||||
return val, true
|
||||
}
|
||||
return agg, true
|
||||
}
|
||||
|
||||
// pcAggregateMethodLimits computes the UNION (least restrictive) of limits
|
||||
// across all provider instances for a given payment type.
|
||||
//
|
||||
// Since the load balancer can route an order to any available instance,
|
||||
// the user should see the widest possible range:
|
||||
// - SingleMin: lowest floor across instances; 0 if any is unlimited
|
||||
// - SingleMax: highest ceiling across instances; 0 if any is unlimited
|
||||
// - DailyLimit: highest cap across instances; 0 if any is unlimited
|
||||
func pcAggregateMethodLimits(pt string, instances []*dbent.PaymentProviderInstance) MethodLimits {
|
||||
ml := MethodLimits{PaymentType: pt}
|
||||
minLimited, maxLimited, dailyLimited := true, true, true
|
||||
|
||||
for _, inst := range instances {
|
||||
cl, hasLimits := pcInstanceTypeLimits(inst, pt)
|
||||
if !hasLimits {
|
||||
return MethodLimits{PaymentType: pt} // any unlimited instance → all zeros
|
||||
}
|
||||
ml.SingleMin, minLimited = unionFloat(ml.SingleMin, minLimited, cl.SingleMin, true)
|
||||
ml.SingleMax, maxLimited = unionFloat(ml.SingleMax, maxLimited, cl.SingleMax, false)
|
||||
ml.DailyLimit, dailyLimited = unionFloat(ml.DailyLimit, dailyLimited, cl.DailyLimit, false)
|
||||
}
|
||||
|
||||
if !minLimited {
|
||||
ml.SingleMin = 0
|
||||
}
|
||||
if !maxLimited {
|
||||
ml.SingleMax = 0
|
||||
}
|
||||
if !dailyLimited {
|
||||
ml.DailyLimit = 0
|
||||
}
|
||||
return ml
|
||||
}
|
||||
|
||||
// pcComputeGlobalRange computes the widest [min, max] across all methods.
|
||||
// Uses the same union logic: lowest min, highest max, 0 if any is unlimited.
|
||||
func pcComputeGlobalRange(methods map[string]MethodLimits) (globalMin, globalMax float64) {
|
||||
minLimited, maxLimited := true, true
|
||||
for _, ml := range methods {
|
||||
globalMin, minLimited = unionFloat(globalMin, minLimited, ml.SingleMin, true)
|
||||
globalMax, maxLimited = unionFloat(globalMax, maxLimited, ml.SingleMax, false)
|
||||
}
|
||||
if !minLimited {
|
||||
globalMin = 0
|
||||
}
|
||||
if !maxLimited {
|
||||
globalMax = 0
|
||||
}
|
||||
return globalMin, globalMax
|
||||
}
|
||||
301
backend/internal/service/payment_config_limits_test.go
Normal file
301
backend/internal/service/payment_config_limits_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestUnionFloat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
agg float64
|
||||
limited bool
|
||||
val float64
|
||||
wantMin bool
|
||||
wantAgg float64
|
||||
wantLimited bool
|
||||
}{
|
||||
{"first non-zero value", 0, true, 5, true, 5, true},
|
||||
{"lower min replaces", 10, true, 3, true, 3, true},
|
||||
{"higher min does not replace", 3, true, 10, true, 3, true},
|
||||
{"higher max replaces", 10, true, 20, false, 20, true},
|
||||
{"lower max does not replace", 20, true, 10, false, 20, true},
|
||||
{"zero value makes unlimited", 5, true, 0, true, 5, false},
|
||||
{"already unlimited stays unlimited", 5, false, 10, true, 5, false},
|
||||
{"zero on first call", 0, true, 0, true, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotAgg, gotLimited := unionFloat(tt.agg, tt.limited, tt.val, tt.wantMin)
|
||||
if gotAgg != tt.wantAgg || gotLimited != tt.wantLimited {
|
||||
t.Fatalf("unionFloat(%v, %v, %v, %v) = (%v, %v), want (%v, %v)",
|
||||
tt.agg, tt.limited, tt.val, tt.wantMin,
|
||||
gotAgg, gotLimited, tt.wantAgg, tt.wantLimited)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeInstance(id int64, providerKey, supportedTypes, limits string) *dbent.PaymentProviderInstance {
|
||||
return &dbent.PaymentProviderInstance{
|
||||
ID: id,
|
||||
ProviderKey: providerKey,
|
||||
SupportedTypes: supportedTypes,
|
||||
Limits: limits,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPcAggregateMethodLimits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("single instance with limits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay,wxpay",
|
||||
`{"alipay":{"singleMin":2,"singleMax":14},"wxpay":{"singleMin":1,"singleMax":12}}`)
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
|
||||
if ml.SingleMin != 2 || ml.SingleMax != 14 {
|
||||
t.Fatalf("alipay limits = min:%v max:%v, want min:2 max:14", ml.SingleMin, ml.SingleMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("two instances union takes widest range", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst1 := makeInstance(1, "easypay", "alipay,wxpay",
|
||||
`{"alipay":{"singleMin":5,"singleMax":100}}`)
|
||||
inst2 := makeInstance(2, "easypay", "alipay,wxpay",
|
||||
`{"alipay":{"singleMin":2,"singleMax":200}}`)
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
|
||||
if ml.SingleMin != 2 {
|
||||
t.Fatalf("SingleMin = %v, want 2 (lowest floor)", ml.SingleMin)
|
||||
}
|
||||
if ml.SingleMax != 200 {
|
||||
t.Fatalf("SingleMax = %v, want 200 (highest ceiling)", ml.SingleMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("one instance unlimited makes aggregate unlimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst1 := makeInstance(1, "easypay", "wxpay",
|
||||
`{"wxpay":{"singleMin":3,"singleMax":10}}`)
|
||||
inst2 := makeInstance(2, "easypay", "wxpay", "") // no limits = unlimited
|
||||
ml := pcAggregateMethodLimits("wxpay", []*dbent.PaymentProviderInstance{inst1, inst2})
|
||||
if ml.SingleMin != 0 || ml.SingleMax != 0 {
|
||||
t.Fatalf("limits = min:%v max:%v, want min:0 max:0 (unlimited)", ml.SingleMin, ml.SingleMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("one field unlimited others limited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst1 := makeInstance(1, "easypay", "alipay",
|
||||
`{"alipay":{"singleMin":5,"singleMax":100}}`)
|
||||
inst2 := makeInstance(2, "easypay", "alipay",
|
||||
`{"alipay":{"singleMin":3,"singleMax":0}}`) // singleMax=0 = unlimited
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
|
||||
if ml.SingleMin != 3 {
|
||||
t.Fatalf("SingleMin = %v, want 3 (lowest floor)", ml.SingleMin)
|
||||
}
|
||||
if ml.SingleMax != 0 {
|
||||
t.Fatalf("SingleMax = %v, want 0 (unlimited)", ml.SingleMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty instances returns zeros", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ml := pcAggregateMethodLimits("alipay", nil)
|
||||
if ml.SingleMin != 0 || ml.SingleMax != 0 || ml.DailyLimit != 0 {
|
||||
t.Fatalf("empty instances should return all zeros, got %+v", ml)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON treated as unlimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay", `{invalid json}`)
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
|
||||
if ml.SingleMin != 0 || ml.SingleMax != 0 {
|
||||
t.Fatalf("invalid JSON should be treated as unlimited, got %+v", ml)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("type not in limits JSON treated as unlimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay,wxpay",
|
||||
`{"wxpay":{"singleMin":1,"singleMax":10}}`) // only wxpay, no alipay
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst})
|
||||
if ml.SingleMin != 0 || ml.SingleMax != 0 {
|
||||
t.Fatalf("missing type should be treated as unlimited, got %+v", ml)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("daily limit aggregation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst1 := makeInstance(1, "easypay", "alipay",
|
||||
`{"alipay":{"singleMin":1,"singleMax":100,"dailyLimit":500}}`)
|
||||
inst2 := makeInstance(2, "easypay", "alipay",
|
||||
`{"alipay":{"singleMin":2,"singleMax":200,"dailyLimit":1000}}`)
|
||||
ml := pcAggregateMethodLimits("alipay", []*dbent.PaymentProviderInstance{inst1, inst2})
|
||||
if ml.DailyLimit != 1000 {
|
||||
t.Fatalf("DailyLimit = %v, want 1000 (highest cap)", ml.DailyLimit)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPcGroupByPaymentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("stripe instance maps all types to stripe group", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripe := makeInstance(1, payment.TypeStripe, "card,alipay,link,wxpay", "")
|
||||
easypay := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
|
||||
|
||||
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe, easypay})
|
||||
|
||||
// Stripe instance should only be in "stripe" group
|
||||
if len(groups[payment.TypeStripe]) != 1 || groups[payment.TypeStripe][0].ID != 1 {
|
||||
t.Fatalf("stripe group should contain only stripe instance, got %v", groups[payment.TypeStripe])
|
||||
}
|
||||
// alipay group should only contain easypay, NOT stripe
|
||||
if len(groups[payment.TypeAlipay]) != 1 || groups[payment.TypeAlipay][0].ID != 2 {
|
||||
t.Fatalf("alipay group should contain only easypay instance, got %v", groups[payment.TypeAlipay])
|
||||
}
|
||||
// wxpay group should only contain easypay, NOT stripe
|
||||
if len(groups[payment.TypeWxpay]) != 1 || groups[payment.TypeWxpay][0].ID != 2 {
|
||||
t.Fatalf("wxpay group should contain only easypay instance, got %v", groups[payment.TypeWxpay])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple easypay instances in same groups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ep1 := makeInstance(1, payment.TypeEasyPay, "alipay,wxpay", "")
|
||||
ep2 := makeInstance(2, payment.TypeEasyPay, "alipay,wxpay", "")
|
||||
|
||||
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{ep1, ep2})
|
||||
|
||||
if len(groups[payment.TypeAlipay]) != 2 {
|
||||
t.Fatalf("alipay group should have 2 instances, got %d", len(groups[payment.TypeAlipay]))
|
||||
}
|
||||
if len(groups[payment.TypeWxpay]) != 2 {
|
||||
t.Fatalf("wxpay group should have 2 instances, got %d", len(groups[payment.TypeWxpay]))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stripe with no supported types still in stripe group", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripe := makeInstance(1, payment.TypeStripe, "", "")
|
||||
|
||||
groups := pcGroupByPaymentType([]*dbent.PaymentProviderInstance{stripe})
|
||||
|
||||
if len(groups[payment.TypeStripe]) != 1 {
|
||||
t.Fatalf("stripe with empty types should still be in stripe group, got %v", groups)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPcComputeGlobalRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("all methods have limits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
methods := map[string]MethodLimits{
|
||||
"alipay": {SingleMin: 2, SingleMax: 14},
|
||||
"wxpay": {SingleMin: 1, SingleMax: 12},
|
||||
"stripe": {SingleMin: 5, SingleMax: 100},
|
||||
}
|
||||
gMin, gMax := pcComputeGlobalRange(methods)
|
||||
if gMin != 1 {
|
||||
t.Fatalf("global min = %v, want 1 (lowest floor)", gMin)
|
||||
}
|
||||
if gMax != 100 {
|
||||
t.Fatalf("global max = %v, want 100 (highest ceiling)", gMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("one method unlimited makes global unlimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
methods := map[string]MethodLimits{
|
||||
"alipay": {SingleMin: 2, SingleMax: 14},
|
||||
"stripe": {SingleMin: 0, SingleMax: 0}, // unlimited
|
||||
}
|
||||
gMin, gMax := pcComputeGlobalRange(methods)
|
||||
if gMin != 0 {
|
||||
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
|
||||
}
|
||||
if gMax != 0 {
|
||||
t.Fatalf("global max = %v, want 0 (unlimited)", gMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty methods returns zeros", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gMin, gMax := pcComputeGlobalRange(map[string]MethodLimits{})
|
||||
if gMin != 0 || gMax != 0 {
|
||||
t.Fatalf("empty methods should return (0, 0), got (%v, %v)", gMin, gMax)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("only min unlimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
methods := map[string]MethodLimits{
|
||||
"alipay": {SingleMin: 0, SingleMax: 100},
|
||||
"wxpay": {SingleMin: 5, SingleMax: 50},
|
||||
}
|
||||
gMin, gMax := pcComputeGlobalRange(methods)
|
||||
if gMin != 0 {
|
||||
t.Fatalf("global min = %v, want 0 (unlimited)", gMin)
|
||||
}
|
||||
if gMax != 100 {
|
||||
t.Fatalf("global max = %v, want 100", gMax)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPcInstanceTypeLimits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty limits string returns false", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay", "")
|
||||
_, ok := pcInstanceTypeLimits(inst, "alipay")
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for empty limits")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("type found returns correct values", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay",
|
||||
`{"alipay":{"singleMin":2,"singleMax":14,"dailyLimit":500}}`)
|
||||
cl, ok := pcInstanceTypeLimits(inst, "alipay")
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if cl.SingleMin != 2 || cl.SingleMax != 14 || cl.DailyLimit != 500 {
|
||||
t.Fatalf("limits = %+v, want min:2 max:14 daily:500", cl)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("type not found returns false", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay",
|
||||
`{"wxpay":{"singleMin":1}}`)
|
||||
_, ok := pcInstanceTypeLimits(inst, "alipay")
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for missing type")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON returns false", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inst := makeInstance(1, "easypay", "alipay", `{bad json}`)
|
||||
_, ok := pcInstanceTypeLimits(inst, "alipay")
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
147
backend/internal/service/payment_config_plans.go
Normal file
147
backend/internal/service/payment_config_plans.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// --- Plan CRUD ---
|
||||
|
||||
// PlanGroupInfo holds the group details needed for subscription plan display.
|
||||
type PlanGroupInfo struct {
|
||||
Platform string `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
ModelScopes []string `json:"supported_model_scopes"`
|
||||
}
|
||||
|
||||
// GetGroupPlatformMap returns a map of group_id → platform for the given plans.
|
||||
func (s *PaymentConfigService) GetGroupPlatformMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]string {
|
||||
info := s.GetGroupInfoMap(ctx, plans)
|
||||
m := make(map[int64]string, len(info))
|
||||
for id, gi := range info {
|
||||
m[id] = gi.Platform
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupInfoMap returns a map of group_id → PlanGroupInfo for the given plans.
|
||||
func (s *PaymentConfigService) GetGroupInfoMap(ctx context.Context, plans []*dbent.SubscriptionPlan) map[int64]PlanGroupInfo {
|
||||
ids := make([]int64, 0, len(plans))
|
||||
seen := make(map[int64]bool)
|
||||
for _, p := range plans {
|
||||
if !seen[p.GroupID] {
|
||||
seen[p.GroupID] = true
|
||||
ids = append(ids, p.GroupID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
groups, err := s.entClient.Group.Query().Where(group.IDIn(ids...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
m := make(map[int64]PlanGroupInfo, len(groups))
|
||||
for _, g := range groups {
|
||||
m[int64(g.ID)] = PlanGroupInfo{
|
||||
Platform: g.Platform,
|
||||
Name: g.Name,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ModelScopes: g.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
|
||||
return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
|
||||
return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
|
||||
b := s.entClient.SubscriptionPlan.Create().
|
||||
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
|
||||
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
|
||||
SetFeatures(req.Features).SetProductName(req.ProductName).
|
||||
SetForSale(req.ForSale).SetSortOrder(req.SortOrder)
|
||||
if req.OriginalPrice != nil {
|
||||
b.SetOriginalPrice(*req.OriginalPrice)
|
||||
}
|
||||
return b.Save(ctx)
|
||||
}
|
||||
|
||||
// UpdatePlan updates a subscription plan by ID (patch semantics).
|
||||
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
|
||||
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
|
||||
u := s.entClient.SubscriptionPlan.UpdateOneID(id)
|
||||
if req.GroupID != nil {
|
||||
u.SetGroupID(*req.GroupID)
|
||||
}
|
||||
if req.Name != nil {
|
||||
u.SetName(*req.Name)
|
||||
}
|
||||
if req.Description != nil {
|
||||
u.SetDescription(*req.Description)
|
||||
}
|
||||
if req.Price != nil {
|
||||
u.SetPrice(*req.Price)
|
||||
}
|
||||
if req.OriginalPrice != nil {
|
||||
u.SetOriginalPrice(*req.OriginalPrice)
|
||||
}
|
||||
if req.ValidityDays != nil {
|
||||
u.SetValidityDays(*req.ValidityDays)
|
||||
}
|
||||
if req.ValidityUnit != nil {
|
||||
u.SetValidityUnit(*req.ValidityUnit)
|
||||
}
|
||||
if req.Features != nil {
|
||||
u.SetFeatures(*req.Features)
|
||||
}
|
||||
if req.ProductName != nil {
|
||||
u.SetProductName(*req.ProductName)
|
||||
}
|
||||
if req.ForSale != nil {
|
||||
u.SetForSale(*req.ForSale)
|
||||
}
|
||||
if req.SortOrder != nil {
|
||||
u.SetSortOrder(*req.SortOrder)
|
||||
}
|
||||
return u.Save(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error {
|
||||
count, err := s.countPendingOrdersByPlan(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check pending orders: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return infraerrors.Conflict("PENDING_ORDERS",
|
||||
fmt.Sprintf("this plan has %d in-progress orders and cannot be deleted — wait for orders to complete first", count))
|
||||
}
|
||||
return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// GetPlan returns a subscription plan by ID.
|
||||
func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) {
|
||||
plan, err := s.entClient.SubscriptionPlan.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found")
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
286
backend/internal/service/payment_config_providers.go
Normal file
286
backend/internal/service/payment_config_providers.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// --- Provider Instance CRUD ---
|
||||
|
||||
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
|
||||
return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx)
|
||||
}
|
||||
|
||||
// ProviderInstanceResponse is the API response for a provider instance.
|
||||
type ProviderInstanceResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Limits string `json:"limits"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
}
|
||||
|
||||
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
|
||||
func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Context) ([]ProviderInstanceResponse, error) {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Order(paymentproviderinstance.BySortOrder()).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]ProviderInstanceResponse, 0, len(instances))
|
||||
for _, inst := range instances {
|
||||
resp := ProviderInstanceResponse{
|
||||
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
|
||||
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
|
||||
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder,
|
||||
PaymentMode: inst.PaymentMode,
|
||||
}
|
||||
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
|
||||
}
|
||||
result = append(result, resp)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
|
||||
return s.decryptConfig(encrypted)
|
||||
}
|
||||
|
||||
// pendingOrderStatuses are order statuses considered "in progress".
|
||||
var pendingOrderStatuses = []string{
|
||||
payment.OrderStatusPending,
|
||||
payment.OrderStatusPaid,
|
||||
payment.OrderStatusRecharging,
|
||||
}
|
||||
|
||||
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
|
||||
|
||||
func isSensitiveConfigField(fieldName string) bool {
|
||||
lower := strings.ToLower(fieldName)
|
||||
for _, p := range sensitiveConfigPatterns {
|
||||
if strings.Contains(lower, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
|
||||
return s.entClient.PaymentOrder.Query().
|
||||
Where(
|
||||
paymentorder.ProviderInstanceIDEQ(strconv.FormatInt(providerInstanceID, 10)),
|
||||
paymentorder.StatusIn(pendingOrderStatuses...),
|
||||
).Count(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, planID int64) (int, error) {
|
||||
return s.entClient.PaymentOrder.Query().
|
||||
Where(
|
||||
paymentorder.PlanIDEQ(planID),
|
||||
paymentorder.StatusIn(pendingOrderStatuses...),
|
||||
).Count(ctx)
|
||||
}
|
||||
|
||||
var validProviderKeys = map[string]bool{
|
||||
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
||||
typesStr := joinTypes(req.SupportedTypes)
|
||||
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := s.encryptConfig(req.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.entClient.PaymentProviderInstance.Create().
|
||||
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
|
||||
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
|
||||
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func validateProviderRequest(providerKey, name, supportedTypes string) error {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return infraerrors.BadRequest("VALIDATION_ERROR", "provider name is required")
|
||||
}
|
||||
if !validProviderKeys[providerKey] {
|
||||
return infraerrors.BadRequest("VALIDATION_ERROR", fmt.Sprintf("invalid provider key: %s", providerKey))
|
||||
}
|
||||
// supported_types can be empty (provider accepts no payment types until configured)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProviderInstance updates a provider instance by ID (patch semantics).
|
||||
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
|
||||
// boilerplate and pending-order safety checks.
|
||||
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
||||
if req.Config != nil {
|
||||
hasSensitive := false
|
||||
for k := range req.Config {
|
||||
if isSensitiveConfigField(k) && req.Config[k] != "" {
|
||||
hasSensitive = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasSensitive {
|
||||
count, err := s.countPendingOrders(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check pending orders: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
|
||||
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.Enabled != nil && !*req.Enabled {
|
||||
count, err := s.countPendingOrders(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check pending orders: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
|
||||
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
|
||||
}
|
||||
}
|
||||
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
|
||||
if req.Name != nil {
|
||||
u.SetName(*req.Name)
|
||||
}
|
||||
if req.Config != nil {
|
||||
merged, err := s.mergeConfig(ctx, id, req.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := s.encryptConfig(merged)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.SetConfig(enc)
|
||||
}
|
||||
if req.SupportedTypes != nil {
|
||||
// Check pending orders before removing payment types
|
||||
count, err := s.countPendingOrders(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check pending orders: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
// Load current instance to compare types
|
||||
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load provider instance: %w", err)
|
||||
}
|
||||
oldTypes := strings.Split(inst.SupportedTypes, ",")
|
||||
newTypes := req.SupportedTypes
|
||||
for _, ot := range oldTypes {
|
||||
ot = strings.TrimSpace(ot)
|
||||
if ot == "" {
|
||||
continue
|
||||
}
|
||||
found := false
|
||||
for _, nt := range newTypes {
|
||||
if strings.TrimSpace(nt) == ot {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, infraerrors.Conflict("PENDING_ORDERS", "cannot remove payment types while instance has pending orders").
|
||||
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
|
||||
}
|
||||
}
|
||||
}
|
||||
u.SetSupportedTypes(joinTypes(req.SupportedTypes))
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
u.SetEnabled(*req.Enabled)
|
||||
}
|
||||
if req.SortOrder != nil {
|
||||
u.SetSortOrder(*req.SortOrder)
|
||||
}
|
||||
if req.Limits != nil {
|
||||
u.SetLimits(*req.Limits)
|
||||
}
|
||||
if req.RefundEnabled != nil {
|
||||
u.SetRefundEnabled(*req.RefundEnabled)
|
||||
}
|
||||
if req.PaymentMode != nil {
|
||||
u.SetPaymentMode(*req.PaymentMode)
|
||||
}
|
||||
return u.Save(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
|
||||
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load existing provider: %w", err)
|
||||
}
|
||||
existing, err := s.decryptConfig(inst.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
|
||||
}
|
||||
if existing == nil {
|
||||
return newConfig, nil
|
||||
}
|
||||
for k, v := range newConfig {
|
||||
existing[k] = v
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
|
||||
if encrypted == "" {
|
||||
return nil, nil
|
||||
}
|
||||
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt config: %w", err)
|
||||
}
|
||||
var raw map[string]string
|
||||
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
|
||||
count, err := s.countPendingOrders(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check pending orders: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return infraerrors.Conflict("PENDING_ORDERS",
|
||||
fmt.Sprintf("this instance has %d in-progress orders and cannot be deleted — wait for orders to complete or disable the instance first", count))
|
||||
}
|
||||
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
enc, err := payment.Encrypt(string(data), s.encryptionKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encrypt config: %w", err)
|
||||
}
|
||||
return enc, nil
|
||||
}
|
||||
187
backend/internal/service/payment_config_providers_test.go
Normal file
187
backend/internal/service/payment_config_providers_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateProviderRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerKey string
|
||||
providerName string
|
||||
supportedTypes string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid easypay with types",
|
||||
providerKey: "easypay",
|
||||
providerName: "MyProvider",
|
||||
supportedTypes: "alipay,wxpay",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid stripe with empty types",
|
||||
providerKey: "stripe",
|
||||
providerName: "Stripe Provider",
|
||||
supportedTypes: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid alipay provider",
|
||||
providerKey: "alipay",
|
||||
providerName: "Alipay Direct",
|
||||
supportedTypes: "alipay",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid wxpay provider",
|
||||
providerKey: "wxpay",
|
||||
providerName: "WeChat Pay",
|
||||
supportedTypes: "wxpay",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid provider key",
|
||||
providerKey: "invalid",
|
||||
providerName: "Name",
|
||||
supportedTypes: "alipay",
|
||||
wantErr: true,
|
||||
errContains: "invalid provider key",
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
providerKey: "easypay",
|
||||
providerName: "",
|
||||
supportedTypes: "alipay",
|
||||
wantErr: true,
|
||||
errContains: "provider name is required",
|
||||
},
|
||||
{
|
||||
name: "whitespace-only name",
|
||||
providerKey: "easypay",
|
||||
providerName: " ",
|
||||
supportedTypes: "alipay",
|
||||
wantErr: true,
|
||||
errContains: "provider name is required",
|
||||
},
|
||||
{
|
||||
name: "tab-only name",
|
||||
providerKey: "easypay",
|
||||
providerName: "\t",
|
||||
supportedTypes: "alipay",
|
||||
wantErr: true,
|
||||
errContains: "provider name is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateProviderRequest(tc.providerKey, tc.providerName, tc.supportedTypes)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveConfigField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
field string
|
||||
wantSen bool
|
||||
}{
|
||||
// Sensitive fields (contain key/secret/private/password/pkey patterns)
|
||||
{"secretKey", true},
|
||||
{"apiSecret", true},
|
||||
{"pkey", true},
|
||||
{"privateKey", true},
|
||||
{"apiPassword", true},
|
||||
{"appKey", true},
|
||||
{"SECRET_TOKEN", true},
|
||||
{"PrivateData", true},
|
||||
{"PASSWORD", true},
|
||||
{"mySecretValue", true},
|
||||
|
||||
// Non-sensitive fields
|
||||
{"appId", false},
|
||||
{"mchId", false},
|
||||
{"apiBase", false},
|
||||
{"endpoint", false},
|
||||
{"merchantNo", false},
|
||||
{"paymentMode", false},
|
||||
{"notifyUrl", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.field, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := isSensitiveConfigField(tc.field)
|
||||
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "multiple types",
|
||||
input: []string{"alipay", "wxpay"},
|
||||
want: "alipay,wxpay",
|
||||
},
|
||||
{
|
||||
name: "single type",
|
||||
input: []string{"stripe"},
|
||||
want: "stripe",
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
input: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil slice",
|
||||
input: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "three types",
|
||||
input: []string{"alipay", "wxpay", "stripe"},
|
||||
want: "alipay,wxpay,stripe",
|
||||
},
|
||||
{
|
||||
name: "types with spaces are not trimmed",
|
||||
input: []string{" alipay ", " wxpay "},
|
||||
want: " alipay , wxpay ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := joinTypes(tc.input)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
351
backend/internal/service/payment_config_service.go
Normal file
351
backend/internal/service/payment_config_service.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
const (
|
||||
SettingPaymentEnabled = "payment_enabled"
|
||||
SettingMinRechargeAmount = "MIN_RECHARGE_AMOUNT"
|
||||
SettingMaxRechargeAmount = "MAX_RECHARGE_AMOUNT"
|
||||
SettingDailyRechargeLimit = "DAILY_RECHARGE_LIMIT"
|
||||
SettingOrderTimeoutMinutes = "ORDER_TIMEOUT_MINUTES"
|
||||
SettingMaxPendingOrders = "MAX_PENDING_ORDERS"
|
||||
SettingEnabledPaymentTypes = "ENABLED_PAYMENT_TYPES"
|
||||
SettingLoadBalanceStrategy = "LOAD_BALANCE_STRATEGY"
|
||||
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
|
||||
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
|
||||
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
|
||||
SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
|
||||
SettingHelpText = "PAYMENT_HELP_TEXT"
|
||||
SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED"
|
||||
SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX"
|
||||
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
|
||||
SettingCancelWindowUnit = "CANCEL_RATE_LIMIT_UNIT"
|
||||
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
|
||||
)
|
||||
|
||||
// Default values for payment configuration settings.
|
||||
const (
|
||||
defaultOrderTimeoutMin = 30
|
||||
defaultMaxPendingOrders = 3
|
||||
)
|
||||
|
||||
// PaymentConfig holds the payment system configuration.
|
||||
type PaymentConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
MinAmount float64 `json:"min_amount"`
|
||||
MaxAmount float64 `json:"max_amount"`
|
||||
DailyLimit float64 `json:"daily_limit"`
|
||||
OrderTimeoutMin int `json:"order_timeout_minutes"`
|
||||
MaxPendingOrders int `json:"max_pending_orders"`
|
||||
EnabledTypes []string `json:"enabled_payment_types"`
|
||||
BalanceDisabled bool `json:"balance_disabled"`
|
||||
LoadBalanceStrategy string `json:"load_balance_strategy"`
|
||||
ProductNamePrefix string `json:"product_name_prefix"`
|
||||
ProductNameSuffix string `json:"product_name_suffix"`
|
||||
HelpImageURL string `json:"help_image_url"`
|
||||
HelpText string `json:"help_text"`
|
||||
StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
|
||||
|
||||
// Cancel rate limit settings
|
||||
CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
|
||||
CancelRateLimitMax int `json:"cancel_rate_limit_max"`
|
||||
CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
|
||||
CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
|
||||
CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
|
||||
}
|
||||
|
||||
// UpdatePaymentConfigRequest contains fields to update payment configuration.
|
||||
type UpdatePaymentConfigRequest struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
MinAmount *float64 `json:"min_amount"`
|
||||
MaxAmount *float64 `json:"max_amount"`
|
||||
DailyLimit *float64 `json:"daily_limit"`
|
||||
OrderTimeoutMin *int `json:"order_timeout_minutes"`
|
||||
MaxPendingOrders *int `json:"max_pending_orders"`
|
||||
EnabledTypes []string `json:"enabled_payment_types"`
|
||||
BalanceDisabled *bool `json:"balance_disabled"`
|
||||
LoadBalanceStrategy *string `json:"load_balance_strategy"`
|
||||
ProductNamePrefix *string `json:"product_name_prefix"`
|
||||
ProductNameSuffix *string `json:"product_name_suffix"`
|
||||
HelpImageURL *string `json:"help_image_url"`
|
||||
HelpText *string `json:"help_text"`
|
||||
|
||||
// Cancel rate limit settings
|
||||
CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
|
||||
CancelRateLimitMax *int `json:"cancel_rate_limit_max"`
|
||||
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
|
||||
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
|
||||
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
|
||||
}
|
||||
|
||||
// MethodLimits holds per-payment-type limits.
|
||||
type MethodLimits struct {
|
||||
PaymentType string `json:"payment_type"`
|
||||
FeeRate float64 `json:"fee_rate"`
|
||||
DailyLimit float64 `json:"daily_limit"`
|
||||
SingleMin float64 `json:"single_min"`
|
||||
SingleMax float64 `json:"single_max"`
|
||||
}
|
||||
|
||||
// MethodLimitsResponse is the full response for the user-facing /limits API.
|
||||
// It includes per-method limits and the global widest range (union of all methods).
|
||||
type MethodLimitsResponse struct {
|
||||
Methods map[string]MethodLimits `json:"methods"`
|
||||
GlobalMin float64 `json:"global_min"` // 0 = no minimum
|
||||
GlobalMax float64 `json:"global_max"` // 0 = no maximum
|
||||
}
|
||||
|
||||
type CreateProviderInstanceRequest struct {
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
Limits string `json:"limits"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
}
|
||||
|
||||
type UpdateProviderInstanceRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
PaymentMode *string `json:"payment_mode"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
Limits *string `json:"limits"`
|
||||
RefundEnabled *bool `json:"refund_enabled"`
|
||||
}
|
||||
type CreatePlanRequest struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Price float64 `json:"price"`
|
||||
OriginalPrice *float64 `json:"original_price"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityUnit string `json:"validity_unit"`
|
||||
Features string `json:"features"`
|
||||
ProductName string `json:"product_name"`
|
||||
ForSale bool `json:"for_sale"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type UpdatePlanRequest struct {
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
Price *float64 `json:"price"`
|
||||
OriginalPrice *float64 `json:"original_price"`
|
||||
ValidityDays *int `json:"validity_days"`
|
||||
ValidityUnit *string `json:"validity_unit"`
|
||||
Features *string `json:"features"`
|
||||
ProductName *string `json:"product_name"`
|
||||
ForSale *bool `json:"for_sale"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// PaymentConfigService manages payment configuration and CRUD for
|
||||
// provider instances, channels, and subscription plans.
|
||||
type PaymentConfigService struct {
|
||||
entClient *dbent.Client
|
||||
settingRepo SettingRepository
|
||||
encryptionKey []byte
|
||||
}
|
||||
|
||||
// NewPaymentConfigService creates a new PaymentConfigService.
|
||||
func NewPaymentConfigService(entClient *dbent.Client, settingRepo SettingRepository, encryptionKey []byte) *PaymentConfigService {
|
||||
return &PaymentConfigService{entClient: entClient, settingRepo: settingRepo, encryptionKey: encryptionKey}
|
||||
}
|
||||
|
||||
// IsPaymentEnabled returns whether the payment system is enabled.
|
||||
func (s *PaymentConfigService) IsPaymentEnabled(ctx context.Context) bool {
|
||||
val, err := s.settingRepo.GetValue(ctx, SettingPaymentEnabled)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return val == "true"
|
||||
}
|
||||
|
||||
// GetPaymentConfig returns the full payment configuration.
|
||||
func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentConfig, error) {
|
||||
keys := []string{
|
||||
SettingPaymentEnabled, SettingMinRechargeAmount, SettingMaxRechargeAmount,
|
||||
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
|
||||
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy,
|
||||
SettingProductNamePrefix, SettingProductNameSuffix,
|
||||
SettingHelpImageURL, SettingHelpText,
|
||||
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
|
||||
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
|
||||
}
|
||||
vals, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get payment config settings: %w", err)
|
||||
}
|
||||
cfg := s.parsePaymentConfig(vals)
|
||||
// Load Stripe publishable key from the first enabled Stripe provider instance
|
||||
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
|
||||
cfg := &PaymentConfig{
|
||||
Enabled: vals[SettingPaymentEnabled] == "true",
|
||||
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
|
||||
MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
|
||||
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
|
||||
OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
|
||||
MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
|
||||
BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
|
||||
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
|
||||
ProductNamePrefix: vals[SettingProductNamePrefix],
|
||||
ProductNameSuffix: vals[SettingProductNameSuffix],
|
||||
HelpImageURL: vals[SettingHelpImageURL],
|
||||
HelpText: vals[SettingHelpText],
|
||||
|
||||
CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
|
||||
CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
|
||||
CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
|
||||
CancelRateLimitUnit: vals[SettingCancelWindowUnit],
|
||||
CancelRateLimitMode: vals[SettingCancelWindowMode],
|
||||
}
|
||||
if cfg.LoadBalanceStrategy == "" {
|
||||
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
|
||||
}
|
||||
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
|
||||
for _, t := range strings.Split(raw, ",") {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" {
|
||||
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
|
||||
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(
|
||||
paymentproviderinstance.EnabledEQ(true),
|
||||
paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe),
|
||||
).Limit(1).All(ctx)
|
||||
if err != nil || len(instances) == 0 {
|
||||
return ""
|
||||
}
|
||||
cfg, err := s.decryptConfig(instances[0].Config)
|
||||
if err != nil || cfg == nil {
|
||||
return ""
|
||||
}
|
||||
return cfg[payment.ConfigKeyPublishableKey]
|
||||
}
|
||||
|
||||
// UpdatePaymentConfig updates the payment configuration settings.
|
||||
// NOTE: This function exceeds 30 lines because each field requires an independent
|
||||
// nil-check before serialisation — this is inherent to patch-style update patterns
|
||||
// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
|
||||
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
|
||||
m := map[string]string{
|
||||
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
|
||||
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
|
||||
SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
|
||||
SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
|
||||
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
|
||||
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
|
||||
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
|
||||
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
|
||||
SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
|
||||
SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
|
||||
SettingHelpImageURL: derefStr(req.HelpImageURL),
|
||||
SettingHelpText: derefStr(req.HelpText),
|
||||
SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
|
||||
SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
|
||||
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
|
||||
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
|
||||
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
|
||||
}
|
||||
if req.EnabledTypes != nil {
|
||||
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
|
||||
} else {
|
||||
m[SettingEnabledPaymentTypes] = ""
|
||||
}
|
||||
return s.settingRepo.SetMultiple(ctx, m)
|
||||
}
|
||||
|
||||
func formatBoolOrEmpty(v *bool) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatBool(*v)
|
||||
}
|
||||
|
||||
func formatPositiveFloat(v *float64) string {
|
||||
if v == nil || *v <= 0 {
|
||||
return "" // empty → parsePaymentConfig uses default
|
||||
}
|
||||
return strconv.FormatFloat(*v, 'f', 2, 64)
|
||||
}
|
||||
|
||||
func formatPositiveInt(v *int) string {
|
||||
if v == nil || *v <= 0 {
|
||||
return ""
|
||||
}
|
||||
return strconv.Itoa(*v)
|
||||
}
|
||||
|
||||
func derefStr(v *string) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func splitTypes(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func joinTypes(types []string) string {
|
||||
return strings.Join(types, ",")
|
||||
}
|
||||
|
||||
func pcParseFloat(s string, defaultVal float64) float64 {
|
||||
if s == "" {
|
||||
return defaultVal
|
||||
}
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func pcParseInt(s string, defaultVal int) int {
|
||||
if s == "" {
|
||||
return defaultVal
|
||||
}
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return v
|
||||
}
|
||||
206
backend/internal/service/payment_config_service_test.go
Normal file
206
backend/internal/service/payment_config_service_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestPcParseFloat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultVal float64
|
||||
expected float64
|
||||
}{
|
||||
{"empty string returns default", "", 1.0, 1.0},
|
||||
{"valid float", "3.14", 0, 3.14},
|
||||
{"valid integer as float", "42", 0, 42.0},
|
||||
{"invalid string returns default", "notanumber", 9.99, 9.99},
|
||||
{"zero value", "0", 5.0, 0},
|
||||
{"negative value", "-10.5", 0, -10.5},
|
||||
{"very large value", "99999999.99", 0, 99999999.99},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := pcParseFloat(tt.input, tt.defaultVal)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("pcParseFloat(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPcParseInt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultVal int
|
||||
expected int
|
||||
}{
|
||||
{"empty string returns default", "", 30, 30},
|
||||
{"valid int", "10", 0, 10},
|
||||
{"invalid string returns default", "abc", 5, 5},
|
||||
{"float string returns default", "3.14", 0, 0},
|
||||
{"zero value", "0", 99, 0},
|
||||
{"negative value", "-1", 0, -1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := pcParseInt(tt.input, tt.defaultVal)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("pcParseInt(%q, %v) = %v, want %v", tt.input, tt.defaultVal, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePaymentConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &PaymentConfigService{}
|
||||
|
||||
t.Run("empty vals uses defaults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := svc.parsePaymentConfig(map[string]string{})
|
||||
if cfg.Enabled {
|
||||
t.Fatal("expected Enabled=false by default")
|
||||
}
|
||||
if cfg.MinAmount != 1 {
|
||||
t.Fatalf("expected MinAmount=1, got %v", cfg.MinAmount)
|
||||
}
|
||||
if cfg.MaxAmount != 0 {
|
||||
t.Fatalf("expected MaxAmount=0 (no limit), got %v", cfg.MaxAmount)
|
||||
}
|
||||
if cfg.OrderTimeoutMin != 30 {
|
||||
t.Fatalf("expected OrderTimeoutMin=30, got %v", cfg.OrderTimeoutMin)
|
||||
}
|
||||
if cfg.MaxPendingOrders != 3 {
|
||||
t.Fatalf("expected MaxPendingOrders=3, got %v", cfg.MaxPendingOrders)
|
||||
}
|
||||
if cfg.LoadBalanceStrategy != payment.DefaultLoadBalanceStrategy {
|
||||
t.Fatalf("expected LoadBalanceStrategy=%s, got %q", payment.DefaultLoadBalanceStrategy, cfg.LoadBalanceStrategy)
|
||||
}
|
||||
if len(cfg.EnabledTypes) != 0 {
|
||||
t.Fatalf("expected empty EnabledTypes, got %v", cfg.EnabledTypes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all values populated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
vals := map[string]string{
|
||||
SettingPaymentEnabled: "true",
|
||||
SettingMinRechargeAmount: "5.00",
|
||||
SettingMaxRechargeAmount: "1000.00",
|
||||
SettingDailyRechargeLimit: "5000.00",
|
||||
SettingOrderTimeoutMinutes: "15",
|
||||
SettingMaxPendingOrders: "5",
|
||||
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
|
||||
SettingBalancePayDisabled: "true",
|
||||
SettingLoadBalanceStrategy: "least_amount",
|
||||
SettingProductNamePrefix: "PRE",
|
||||
SettingProductNameSuffix: "SUF",
|
||||
}
|
||||
cfg := svc.parsePaymentConfig(vals)
|
||||
|
||||
if !cfg.Enabled {
|
||||
t.Fatal("expected Enabled=true")
|
||||
}
|
||||
if cfg.MinAmount != 5 {
|
||||
t.Fatalf("MinAmount = %v, want 5", cfg.MinAmount)
|
||||
}
|
||||
if cfg.MaxAmount != 1000 {
|
||||
t.Fatalf("MaxAmount = %v, want 1000", cfg.MaxAmount)
|
||||
}
|
||||
if cfg.DailyLimit != 5000 {
|
||||
t.Fatalf("DailyLimit = %v, want 5000", cfg.DailyLimit)
|
||||
}
|
||||
if cfg.OrderTimeoutMin != 15 {
|
||||
t.Fatalf("OrderTimeoutMin = %v, want 15", cfg.OrderTimeoutMin)
|
||||
}
|
||||
if cfg.MaxPendingOrders != 5 {
|
||||
t.Fatalf("MaxPendingOrders = %v, want 5", cfg.MaxPendingOrders)
|
||||
}
|
||||
if len(cfg.EnabledTypes) != 3 {
|
||||
t.Fatalf("EnabledTypes len = %d, want 3", len(cfg.EnabledTypes))
|
||||
}
|
||||
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" || cfg.EnabledTypes[2] != "stripe" {
|
||||
t.Fatalf("EnabledTypes = %v, want [alipay wxpay stripe]", cfg.EnabledTypes)
|
||||
}
|
||||
if !cfg.BalanceDisabled {
|
||||
t.Fatal("expected BalanceDisabled=true")
|
||||
}
|
||||
if cfg.LoadBalanceStrategy != "least_amount" {
|
||||
t.Fatalf("LoadBalanceStrategy = %q, want %q", cfg.LoadBalanceStrategy, "least_amount")
|
||||
}
|
||||
if cfg.ProductNamePrefix != "PRE" {
|
||||
t.Fatalf("ProductNamePrefix = %q, want %q", cfg.ProductNamePrefix, "PRE")
|
||||
}
|
||||
if cfg.ProductNameSuffix != "SUF" {
|
||||
t.Fatalf("ProductNameSuffix = %q, want %q", cfg.ProductNameSuffix, "SUF")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled types with spaces are trimmed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
vals := map[string]string{
|
||||
SettingEnabledPaymentTypes: " alipay , wxpay ",
|
||||
}
|
||||
cfg := svc.parsePaymentConfig(vals)
|
||||
if len(cfg.EnabledTypes) != 2 {
|
||||
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
|
||||
}
|
||||
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
|
||||
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty enabled types string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
vals := map[string]string{
|
||||
SettingEnabledPaymentTypes: "",
|
||||
}
|
||||
cfg := svc.parsePaymentConfig(vals)
|
||||
if len(cfg.EnabledTypes) != 0 {
|
||||
t.Fatalf("expected empty EnabledTypes for empty string, got %v", cfg.EnabledTypes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetBasePaymentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{payment.TypeEasyPay, payment.TypeEasyPay},
|
||||
{payment.TypeStripe, payment.TypeStripe},
|
||||
{payment.TypeCard, payment.TypeStripe},
|
||||
{payment.TypeLink, payment.TypeStripe},
|
||||
{payment.TypeAlipay, payment.TypeAlipay},
|
||||
{payment.TypeAlipayDirect, payment.TypeAlipay},
|
||||
{payment.TypeWxpay, payment.TypeWxpay},
|
||||
{payment.TypeWxpayDirect, payment.TypeWxpay},
|
||||
{"unknown", "unknown"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := payment.GetBasePaymentType(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("GetBasePaymentType(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
325
backend/internal/service/payment_fulfillment.go
Normal file
325
backend/internal/service/payment_fulfillment.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// --- Payment Notification & Fulfillment ---
|
||||
|
||||
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
|
||||
if n.Status != payment.NotificationStatusSuccess {
|
||||
return nil
|
||||
}
|
||||
// Look up order by out_trade_no (the external order ID we sent to the provider)
|
||||
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
|
||||
if err != nil {
|
||||
// Fallback: try legacy format (sub2_N where N is DB ID)
|
||||
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
|
||||
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
|
||||
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
|
||||
}
|
||||
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
|
||||
}
|
||||
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
|
||||
}
|
||||
|
||||
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
slog.Error("order not found", "orderID", oid)
|
||||
return nil
|
||||
}
|
||||
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
|
||||
// Also skip if paid is NaN/Inf (malformed provider data).
|
||||
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
|
||||
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
|
||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
|
||||
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
|
||||
}
|
||||
}
|
||||
// Use order's expected amount when provider didn't report one
|
||||
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
|
||||
paid = o.PayAmount
|
||||
}
|
||||
return s.toPaid(ctx, o, tradeNo, paid, pk)
|
||||
}
|
||||
|
||||
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
|
||||
previousStatus := o.Status
|
||||
now := time.Now()
|
||||
grace := now.Add(-paymentGraceMinutes * time.Minute)
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(
|
||||
paymentorder.IDEQ(o.ID),
|
||||
paymentorder.Or(
|
||||
paymentorder.StatusEQ(OrderStatusPending),
|
||||
paymentorder.StatusEQ(OrderStatusCancelled),
|
||||
paymentorder.And(
|
||||
paymentorder.StatusEQ(OrderStatusExpired),
|
||||
paymentorder.UpdatedAtGTE(grace),
|
||||
),
|
||||
),
|
||||
).SetStatus(OrderStatusPaid).SetPayAmount(paid).SetPaymentTradeNo(tradeNo).SetPaidAt(now).ClearFailedAt().ClearFailedReason().Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update to PAID: %w", err)
|
||||
}
|
||||
if c == 0 {
|
||||
return s.alreadyProcessed(ctx, o)
|
||||
}
|
||||
if previousStatus == OrderStatusCancelled || previousStatus == OrderStatusExpired {
|
||||
slog.Info("order recovered from webhook payment success",
|
||||
"orderID", o.ID,
|
||||
"previousStatus", previousStatus,
|
||||
"tradeNo", tradeNo,
|
||||
"provider", pk,
|
||||
)
|
||||
s.writeAuditLog(ctx, o.ID, "ORDER_RECOVERED", pk, map[string]any{
|
||||
"previous_status": previousStatus,
|
||||
"tradeNo": tradeNo,
|
||||
"paidAmount": paid,
|
||||
"reason": "webhook payment success received after order " + previousStatus,
|
||||
})
|
||||
}
|
||||
s.writeAuditLog(ctx, o.ID, "ORDER_PAID", pk, map[string]any{"tradeNo": tradeNo, "paidAmount": paid})
|
||||
return s.executeFulfillment(ctx, o.ID)
|
||||
}
|
||||
|
||||
func (s *PaymentService) alreadyProcessed(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
cur, err := s.entClient.PaymentOrder.Get(ctx, o.ID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
switch cur.Status {
|
||||
case OrderStatusCompleted, OrderStatusRefunded:
|
||||
return nil
|
||||
case OrderStatusFailed:
|
||||
return s.executeFulfillment(ctx, o.ID)
|
||||
case OrderStatusPaid, OrderStatusRecharging:
|
||||
return fmt.Errorf("order %d is being processed", o.ID)
|
||||
case OrderStatusExpired:
|
||||
slog.Warn("webhook payment success for expired order beyond grace period",
|
||||
"orderID", o.ID,
|
||||
"status", cur.Status,
|
||||
"updatedAt", cur.UpdatedAt,
|
||||
)
|
||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_AFTER_EXPIRY", "system", map[string]any{
|
||||
"status": cur.Status,
|
||||
"updatedAt": cur.UpdatedAt,
|
||||
"reason": "payment arrived after expiry grace period",
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get order: %w", err)
|
||||
}
|
||||
if o.OrderType == payment.OrderTypeSubscription {
|
||||
return s.ExecuteSubscriptionFulfillment(ctx, oid)
|
||||
}
|
||||
return s.ExecuteBalanceFulfillment(ctx, oid)
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int64) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.Status == OrderStatusCompleted {
|
||||
return nil
|
||||
}
|
||||
if psIsRefundStatus(o.Status) {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
|
||||
}
|
||||
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
|
||||
}
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lock: %w", err)
|
||||
}
|
||||
if c == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.doBalance(ctx, o); err != nil {
|
||||
s.markFailed(ctx, oid, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// redeemAction represents the idempotency decision for balance fulfillment.
|
||||
type redeemAction int
|
||||
|
||||
const (
|
||||
// redeemActionCreate: code does not exist — create it, then redeem.
|
||||
redeemActionCreate redeemAction = iota
|
||||
// redeemActionRedeem: code exists but is unused — skip creation, redeem only.
|
||||
redeemActionRedeem
|
||||
// redeemActionSkipCompleted: code exists and is already used — skip to mark completed.
|
||||
redeemActionSkipCompleted
|
||||
)
|
||||
|
||||
// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup.
|
||||
// existing is the result of GetByCode; lookupErr is the error from that call.
|
||||
func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction {
|
||||
if existing == nil || lookupErr != nil {
|
||||
return redeemActionCreate
|
||||
}
|
||||
if existing.IsUsed() {
|
||||
return redeemActionSkipCompleted
|
||||
}
|
||||
return redeemActionRedeem
|
||||
}
|
||||
|
||||
func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
// Idempotency: check if redeem code already exists (from a previous partial run)
|
||||
existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode)
|
||||
action := resolveRedeemAction(existing, lookupErr)
|
||||
|
||||
switch action {
|
||||
case redeemActionSkipCompleted:
|
||||
// Code already created and redeemed — just mark completed
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
case redeemActionCreate:
|
||||
rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused}
|
||||
if err := s.redeemService.CreateCode(ctx, rc); err != nil {
|
||||
return fmt.Errorf("create redeem code: %w", err)
|
||||
}
|
||||
case redeemActionRedeem:
|
||||
// Code exists but unused — skip creation, proceed to redeem
|
||||
}
|
||||
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
|
||||
return fmt.Errorf("redeem balance: %w", err)
|
||||
}
|
||||
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
|
||||
}
|
||||
|
||||
func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrder, auditAction string) error {
|
||||
now := time.Now()
|
||||
_, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mark completed: %w", err)
|
||||
}
|
||||
s.writeAuditLog(ctx, o.ID, auditAction, "system", map[string]any{"rechargeCode": o.RechargeCode, "amount": o.Amount})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.Status == OrderStatusCompleted {
|
||||
return nil
|
||||
}
|
||||
if psIsRefundStatus(o.Status) {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot fulfill")
|
||||
}
|
||||
if o.Status != OrderStatusPaid && o.Status != OrderStatusFailed {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "order cannot fulfill in status "+o.Status)
|
||||
}
|
||||
if o.SubscriptionGroupID == nil || o.SubscriptionDays == nil {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "missing subscription info")
|
||||
}
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusPaid, OrderStatusFailed)).SetStatus(OrderStatusRecharging).Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lock: %w", err)
|
||||
}
|
||||
if c == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := s.doSub(ctx, o); err != nil {
|
||||
s.markFailed(ctx, oid, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
gid := *o.SubscriptionGroupID
|
||||
days := *o.SubscriptionDays
|
||||
g, err := s.groupRepo.GetByID(ctx, gid)
|
||||
if err != nil || g.Status != payment.EntityStatusActive {
|
||||
return fmt.Errorf("group %d no longer exists or inactive", gid)
|
||||
}
|
||||
// Idempotency: check audit log to see if subscription was already assigned.
|
||||
// Prevents double-extension on retry after markCompleted fails.
|
||||
if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") {
|
||||
slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid)
|
||||
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
|
||||
}
|
||||
orderNote := fmt.Sprintf("payment order %d", o.ID)
|
||||
_, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote})
|
||||
if err != nil {
|
||||
return fmt.Errorf("assign subscription: %w", err)
|
||||
}
|
||||
return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
|
||||
}
|
||||
|
||||
func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool {
|
||||
oid := strconv.FormatInt(orderID, 10)
|
||||
c, _ := s.entClient.PaymentAuditLog.Query().
|
||||
Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)).
|
||||
Limit(1).Count(ctx)
|
||||
return c > 0
|
||||
}
|
||||
|
||||
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
|
||||
now := time.Now()
|
||||
r := psErrMsg(cause)
|
||||
// Only mark FAILED if still in RECHARGING state — prevents overwriting
|
||||
// a COMPLETED order when markCompleted failed but fulfillment succeeded.
|
||||
c, e := s.entClient.PaymentOrder.Update().
|
||||
Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)).
|
||||
SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
|
||||
if e != nil {
|
||||
slog.Error("mark FAILED", "orderID", oid, "error", e)
|
||||
}
|
||||
if c > 0 {
|
||||
s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.PaidAt == nil {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "order is not paid")
|
||||
}
|
||||
if psIsRefundStatus(o.Status) {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "refund-related order cannot retry")
|
||||
}
|
||||
if o.Status == OrderStatusRecharging {
|
||||
return infraerrors.Conflict("CONFLICT", "order is being processed")
|
||||
}
|
||||
if o.Status == OrderStatusCompleted {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "order already completed")
|
||||
}
|
||||
if o.Status != OrderStatusFailed && o.Status != OrderStatusPaid {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "only paid and failed orders can retry")
|
||||
}
|
||||
_, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.StatusIn(OrderStatusFailed, OrderStatusPaid)).SetStatus(OrderStatusPaid).ClearFailedAt().ClearFailedReason().Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset for retry: %w", err)
|
||||
}
|
||||
s.writeAuditLog(ctx, oid, "RECHARGE_RETRY", "admin", map[string]any{"detail": "admin manual retry"})
|
||||
return s.executeFulfillment(ctx, oid)
|
||||
}
|
||||
163
backend/internal/service/payment_fulfillment_test.go
Normal file
163
backend/internal/service/payment_fulfillment_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// resolveRedeemAction — pure idempotency decision logic
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveRedeemAction_CodeNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
action := resolveRedeemAction(nil, nil)
|
||||
assert.Equal(t, redeemActionCreate, action, "nil code with nil error should create")
|
||||
}
|
||||
|
||||
func TestResolveRedeemAction_LookupError(t *testing.T) {
|
||||
t.Parallel()
|
||||
action := resolveRedeemAction(nil, errors.New("db connection lost"))
|
||||
assert.Equal(t, redeemActionCreate, action, "lookup error should fall back to create")
|
||||
}
|
||||
|
||||
func TestResolveRedeemAction_LookupErrorWithNonNilCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Edge case: both code and error are non-nil (shouldn't happen in practice,
|
||||
// but the function should still treat error as authoritative)
|
||||
code := &RedeemCode{Status: StatusUnused}
|
||||
action := resolveRedeemAction(code, errors.New("partial error"))
|
||||
assert.Equal(t, redeemActionCreate, action, "non-nil error should always result in create regardless of code")
|
||||
}
|
||||
|
||||
func TestResolveRedeemAction_CodeExistsAndUsed(t *testing.T) {
|
||||
t.Parallel()
|
||||
code := &RedeemCode{
|
||||
Code: "test-code-123",
|
||||
Status: StatusUsed,
|
||||
Type: RedeemTypeBalance,
|
||||
Value: 10.0,
|
||||
}
|
||||
action := resolveRedeemAction(code, nil)
|
||||
assert.Equal(t, redeemActionSkipCompleted, action, "used code should skip to completed")
|
||||
}
|
||||
|
||||
func TestResolveRedeemAction_CodeExistsAndUnused(t *testing.T) {
|
||||
t.Parallel()
|
||||
code := &RedeemCode{
|
||||
Code: "test-code-456",
|
||||
Status: StatusUnused,
|
||||
Type: RedeemTypeBalance,
|
||||
Value: 25.0,
|
||||
}
|
||||
action := resolveRedeemAction(code, nil)
|
||||
assert.Equal(t, redeemActionRedeem, action, "unused code should skip creation and proceed to redeem")
|
||||
}
|
||||
|
||||
func TestResolveRedeemAction_CodeExistsWithExpiredStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
// A code with a non-standard status (neither "unused" nor "used")
|
||||
// should NOT be treated as used, so it falls through to redeemActionRedeem.
|
||||
code := &RedeemCode{
|
||||
Code: "expired-code",
|
||||
Status: StatusExpired,
|
||||
}
|
||||
action := resolveRedeemAction(code, nil)
|
||||
assert.Equal(t, redeemActionRedeem, action, "expired-status code is not IsUsed(), should redeem")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Table-driven comprehensive test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveRedeemAction_Table(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code *RedeemCode
|
||||
err error
|
||||
expected redeemAction
|
||||
}{
|
||||
{
|
||||
name: "nil code, nil error — first run",
|
||||
code: nil,
|
||||
err: nil,
|
||||
expected: redeemActionCreate,
|
||||
},
|
||||
{
|
||||
name: "nil code, lookup error — treat as not found",
|
||||
code: nil,
|
||||
err: ErrRedeemCodeNotFound,
|
||||
expected: redeemActionCreate,
|
||||
},
|
||||
{
|
||||
name: "nil code, generic DB error — treat as not found",
|
||||
code: nil,
|
||||
err: errors.New("connection refused"),
|
||||
expected: redeemActionCreate,
|
||||
},
|
||||
{
|
||||
name: "code exists, used — previous run completed redeem",
|
||||
code: &RedeemCode{Status: StatusUsed},
|
||||
err: nil,
|
||||
expected: redeemActionSkipCompleted,
|
||||
},
|
||||
{
|
||||
name: "code exists, unused — previous run created code but crashed before redeem",
|
||||
code: &RedeemCode{Status: StatusUnused},
|
||||
err: nil,
|
||||
expected: redeemActionRedeem,
|
||||
},
|
||||
{
|
||||
name: "code exists but error also set — error takes precedence",
|
||||
code: &RedeemCode{Status: StatusUsed},
|
||||
err: errors.New("unexpected"),
|
||||
expected: redeemActionCreate,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := resolveRedeemAction(tt.code, tt.err)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// redeemAction enum value sanity
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRedeemAction_DistinctValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Ensure the three actions have distinct values (iota correctness)
|
||||
assert.NotEqual(t, redeemActionCreate, redeemActionRedeem)
|
||||
assert.NotEqual(t, redeemActionCreate, redeemActionSkipCompleted)
|
||||
assert.NotEqual(t, redeemActionRedeem, redeemActionSkipCompleted)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RedeemCode.IsUsed / CanUse interaction with resolveRedeemAction
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
usedCode := &RedeemCode{Status: StatusUsed}
|
||||
unusedCode := &RedeemCode{Status: StatusUnused}
|
||||
|
||||
// Verify our decision function is consistent with the domain model methods
|
||||
assert.True(t, usedCode.IsUsed())
|
||||
assert.False(t, usedCode.CanUse())
|
||||
assert.Equal(t, redeemActionSkipCompleted, resolveRedeemAction(usedCode, nil))
|
||||
|
||||
assert.False(t, unusedCode.IsUsed())
|
||||
assert.True(t, unusedCode.CanUse())
|
||||
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
|
||||
}
|
||||
546
backend/internal/service/payment_order.go
Normal file
546
backend/internal/service/payment_order.go
Normal file
@@ -0,0 +1,546 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// --- Order Creation ---
|
||||
|
||||
func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest) (*CreateOrderResponse, error) {
|
||||
if req.OrderType == "" {
|
||||
req.OrderType = payment.OrderTypeBalance
|
||||
}
|
||||
cfg, err := s.configService.GetPaymentConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get payment config: %w", err)
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return nil, infraerrors.Forbidden("PAYMENT_DISABLED", "payment system is disabled")
|
||||
}
|
||||
plan, err := s.validateOrderInput(ctx, req, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.checkCancelRateLimit(ctx, req.UserID, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := s.userRepo.GetByID(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if user.Status != payment.EntityStatusActive {
|
||||
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
|
||||
}
|
||||
amount := req.Amount
|
||||
if plan != nil {
|
||||
amount = plan.Price
|
||||
}
|
||||
feeRate := s.getFeeRate(req.PaymentType)
|
||||
payAmountStr := payment.CalculatePayAmount(amount, feeRate)
|
||||
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
|
||||
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, amount, feeRate, payAmount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := s.invokeProvider(ctx, order, req, cfg, payAmountStr, payAmount, plan)
|
||||
if err != nil {
|
||||
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
|
||||
SetStatus(OrderStatusFailed).
|
||||
Save(ctx)
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig) (*dbent.SubscriptionPlan, error) {
|
||||
if req.OrderType == payment.OrderTypeBalance && cfg.BalanceDisabled {
|
||||
return nil, infraerrors.Forbidden("BALANCE_PAYMENT_DISABLED", "balance recharge has been disabled")
|
||||
}
|
||||
if req.OrderType == payment.OrderTypeSubscription {
|
||||
return s.validateSubOrder(ctx, req)
|
||||
}
|
||||
if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number")
|
||||
}
|
||||
if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) {
|
||||
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range").
|
||||
WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)})
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRequest) (*dbent.SubscriptionPlan, error) {
|
||||
if req.PlanID == 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "subscription order requires a plan")
|
||||
}
|
||||
plan, err := s.configService.GetPlan(ctx, req.PlanID)
|
||||
if err != nil || !plan.ForSale {
|
||||
return nil, infraerrors.NotFound("PLAN_NOT_AVAILABLE", "plan not found or not for sale")
|
||||
}
|
||||
group, err := s.groupRepo.GetByID(ctx, plan.GroupID)
|
||||
if err != nil || group.Status != payment.EntityStatusActive {
|
||||
return nil, infraerrors.NotFound("GROUP_NOT_FOUND", "subscription group is no longer available")
|
||||
}
|
||||
if !group.IsSubscriptionType() {
|
||||
return nil, infraerrors.BadRequest("GROUP_TYPE_MISMATCH", "group is not a subscription type")
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, amount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
if err := s.checkPendingLimit(ctx, tx, req.UserID, cfg.MaxPendingOrders); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.checkDailyLimit(ctx, tx, req.UserID, amount, cfg.DailyLimit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tm := cfg.OrderTimeoutMin
|
||||
if tm <= 0 {
|
||||
tm = defaultOrderTimeoutMin
|
||||
}
|
||||
exp := time.Now().Add(time.Duration(tm) * time.Minute)
|
||||
b := tx.PaymentOrder.Create().
|
||||
SetUserID(req.UserID).
|
||||
SetUserEmail(user.Email).
|
||||
SetUserName(user.Username).
|
||||
SetNillableUserNotes(psNilIfEmpty(user.Notes)).
|
||||
SetAmount(amount).
|
||||
SetPayAmount(payAmount).
|
||||
SetFeeRate(feeRate).
|
||||
SetRechargeCode("").
|
||||
SetOutTradeNo(generateOutTradeNo()).
|
||||
SetPaymentType(req.PaymentType).
|
||||
SetPaymentTradeNo("").
|
||||
SetOrderType(req.OrderType).
|
||||
SetStatus(OrderStatusPending).
|
||||
SetExpiresAt(exp).
|
||||
SetClientIP(req.ClientIP).
|
||||
SetSrcHost(req.SrcHost)
|
||||
if req.SrcURL != "" {
|
||||
b.SetSrcURL(req.SrcURL)
|
||||
}
|
||||
if plan != nil {
|
||||
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
|
||||
}
|
||||
order, err := b.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create order: %w", err)
|
||||
}
|
||||
code := fmt.Sprintf("PAY-%d-%d", order.ID, time.Now().UnixNano()%100000)
|
||||
order, err = tx.PaymentOrder.UpdateOneID(order.ID).SetRechargeCode(code).Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set recharge code: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit order transaction: %w", err)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
|
||||
if max <= 0 {
|
||||
max = defaultMaxPendingOrders
|
||||
}
|
||||
c, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusEQ(OrderStatusPending)).Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count pending orders: %w", err)
|
||||
}
|
||||
if c >= max {
|
||||
return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
|
||||
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
|
||||
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
|
||||
return nil
|
||||
}
|
||||
windowStart := cancelRateLimitWindowStart(cfg)
|
||||
operator := fmt.Sprintf("user:%d", userID)
|
||||
count, err := s.entClient.PaymentAuditLog.Query().
|
||||
Where(
|
||||
paymentauditlog.ActionEQ("ORDER_CANCELLED"),
|
||||
paymentauditlog.OperatorEQ(operator),
|
||||
paymentauditlog.CreatedAtGTE(windowStart),
|
||||
).Count(ctx)
|
||||
if err != nil {
|
||||
slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
|
||||
return nil // fail open
|
||||
}
|
||||
if count >= cfg.CancelRateLimitMax {
|
||||
return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
|
||||
WithMetadata(map[string]string{
|
||||
"max": strconv.Itoa(cfg.CancelRateLimitMax),
|
||||
"window": strconv.Itoa(cfg.CancelRateLimitWindow),
|
||||
"unit": cfg.CancelRateLimitUnit,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
|
||||
now := time.Now()
|
||||
w := cfg.CancelRateLimitWindow
|
||||
if w <= 0 {
|
||||
w = 1
|
||||
}
|
||||
unit := cfg.CancelRateLimitUnit
|
||||
if unit == "" {
|
||||
unit = "day"
|
||||
}
|
||||
if cfg.CancelRateLimitMode == "fixed" {
|
||||
switch unit {
|
||||
case "minute":
|
||||
t := now.Truncate(time.Minute)
|
||||
return t.Add(-time.Duration(w-1) * time.Minute)
|
||||
case "day":
|
||||
y, m, d := now.Date()
|
||||
t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
|
||||
return t.AddDate(0, 0, -(w - 1))
|
||||
default: // hour
|
||||
t := now.Truncate(time.Hour)
|
||||
return t.Add(-time.Duration(w-1) * time.Hour)
|
||||
}
|
||||
}
|
||||
// rolling window
|
||||
switch unit {
|
||||
case "minute":
|
||||
return now.Add(-time.Duration(w) * time.Minute)
|
||||
case "day":
|
||||
return now.AddDate(0, 0, -w)
|
||||
default: // hour
|
||||
return now.Add(-time.Duration(w) * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
ts := psStartOfDayUTC(time.Now())
|
||||
orders, err := tx.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID), paymentorder.StatusIn(OrderStatusPaid, OrderStatusRecharging, OrderStatusCompleted), paymentorder.PaidAtGTE(ts)).All(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query daily usage: %w", err)
|
||||
}
|
||||
var used float64
|
||||
for _, o := range orders {
|
||||
used += o.Amount
|
||||
}
|
||||
if used+amount > limit {
|
||||
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
|
||||
s.EnsureProviders(ctx)
|
||||
providerKey := s.registry.GetProviderKey(req.PaymentType)
|
||||
if providerKey == "" {
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
|
||||
}
|
||||
sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("select provider instance: %w", err)
|
||||
}
|
||||
if sel == nil {
|
||||
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
|
||||
}
|
||||
prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config)
|
||||
if err != nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
|
||||
}
|
||||
subject := s.buildPaymentSubject(plan, payAmountStr, cfg)
|
||||
outTradeNo := order.OutTradeNo
|
||||
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
|
||||
if err != nil {
|
||||
slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err)
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
|
||||
}
|
||||
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update order with payment details: %w", err)
|
||||
}
|
||||
s.writeAuditLog(ctx, order.ID, "ORDER_CREATED", fmt.Sprintf("user:%d", req.UserID), map[string]any{"amount": req.Amount, "paymentType": req.PaymentType, "orderType": req.OrderType})
|
||||
return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, payAmountStr string, cfg *PaymentConfig) string {
|
||||
if plan != nil {
|
||||
if plan.ProductName != "" {
|
||||
return plan.ProductName
|
||||
}
|
||||
return "Sub2API Subscription " + plan.Name
|
||||
}
|
||||
pf := strings.TrimSpace(cfg.ProductNamePrefix)
|
||||
sf := strings.TrimSpace(cfg.ProductNameSuffix)
|
||||
if pf != "" || sf != "" {
|
||||
return strings.TrimSpace(pf + " " + payAmountStr + " " + sf)
|
||||
}
|
||||
return "Sub2API " + payAmountStr + " CNY"
|
||||
}
|
||||
|
||||
// --- Order Queries ---
|
||||
|
||||
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.UserID != userID {
|
||||
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) GetOrderByID(ctx context.Context, orderID int64) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) GetUserOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
|
||||
q := s.entClient.PaymentOrder.Query().Where(paymentorder.UserIDEQ(userID))
|
||||
if p.Status != "" {
|
||||
q = q.Where(paymentorder.StatusEQ(p.Status))
|
||||
}
|
||||
if p.OrderType != "" {
|
||||
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
|
||||
}
|
||||
if p.PaymentType != "" {
|
||||
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
|
||||
}
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count user orders: %w", err)
|
||||
}
|
||||
ps, pg := applyPagination(p.PageSize, p.Page)
|
||||
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("query user orders: %w", err)
|
||||
}
|
||||
return orders, total, nil
|
||||
}
|
||||
|
||||
// AdminListOrders returns a paginated list of orders. If userID > 0, filters by user.
|
||||
func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p OrderListParams) ([]*dbent.PaymentOrder, int, error) {
|
||||
q := s.entClient.PaymentOrder.Query()
|
||||
if userID > 0 {
|
||||
q = q.Where(paymentorder.UserIDEQ(userID))
|
||||
}
|
||||
if p.Status != "" {
|
||||
q = q.Where(paymentorder.StatusEQ(p.Status))
|
||||
}
|
||||
if p.OrderType != "" {
|
||||
q = q.Where(paymentorder.OrderTypeEQ(p.OrderType))
|
||||
}
|
||||
if p.PaymentType != "" {
|
||||
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
|
||||
}
|
||||
if p.Keyword != "" {
|
||||
q = q.Where(paymentorder.Or(
|
||||
paymentorder.OutTradeNoContainsFold(p.Keyword),
|
||||
paymentorder.UserEmailContainsFold(p.Keyword),
|
||||
paymentorder.UserNameContainsFold(p.Keyword),
|
||||
))
|
||||
}
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count admin orders: %w", err)
|
||||
}
|
||||
ps, pg := applyPagination(p.PageSize, p.Page)
|
||||
orders, err := q.Order(dbent.Desc(paymentorder.FieldCreatedAt)).Limit(ps).Offset((pg - 1) * ps).All(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("query admin orders: %w", err)
|
||||
}
|
||||
return orders, total, nil
|
||||
}
|
||||
|
||||
// --- Cancel & Expire ---
|
||||
|
||||
func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
|
||||
if err != nil {
|
||||
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.UserID != userID {
|
||||
return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
|
||||
}
|
||||
if o.Status != OrderStatusPending {
|
||||
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
|
||||
}
|
||||
return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
|
||||
}
|
||||
|
||||
func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
|
||||
if err != nil {
|
||||
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.Status != OrderStatusPending {
|
||||
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
|
||||
}
|
||||
return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
|
||||
}
|
||||
|
||||
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
|
||||
if o.PaymentTradeNo != "" || o.PaymentType != "" {
|
||||
if s.checkPaid(ctx, o) == "already_paid" {
|
||||
return "already_paid", nil
|
||||
}
|
||||
}
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("update order status: %w", err)
|
||||
}
|
||||
if c > 0 {
|
||||
auditAction := "ORDER_CANCELLED"
|
||||
if fs == OrderStatusExpired {
|
||||
auditAction = "ORDER_EXPIRED"
|
||||
}
|
||||
s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
|
||||
}
|
||||
return "cancelled", nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
|
||||
prov, err := s.getOrderProvider(ctx, o)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// Use OutTradeNo as fallback when PaymentTradeNo is empty
|
||||
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
|
||||
tradeNo := o.PaymentTradeNo
|
||||
if tradeNo == "" {
|
||||
tradeNo = o.OutTradeNo
|
||||
}
|
||||
resp, err := prov.QueryOrder(ctx, tradeNo)
|
||||
if err != nil {
|
||||
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
|
||||
return ""
|
||||
}
|
||||
if resp.Status == payment.ProviderStatusPaid {
|
||||
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
|
||||
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
|
||||
// Still return already_paid — order was paid, fulfillment can be retried
|
||||
}
|
||||
return "already_paid"
|
||||
}
|
||||
if cp, ok := prov.(payment.CancelableProvider); ok {
|
||||
_ = cp.CancelPayment(ctx, tradeNo)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
|
||||
// if a payment was made, and processes it if so. This handles the case where
|
||||
// the provider's notify callback was missed (e.g. EasyPay popup mode).
|
||||
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.UserID != userID {
|
||||
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
|
||||
}
|
||||
// Only verify orders that are still pending or recently expired
|
||||
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
|
||||
result := s.checkPaid(ctx, o)
|
||||
if result == "already_paid" {
|
||||
// Reload order to get updated status
|
||||
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reload order: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// VerifyOrderPublic verifies payment status without user authentication.
|
||||
// Used by the payment result page when the user's session has expired.
|
||||
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
|
||||
result := s.checkPaid(ctx, o)
|
||||
if result == "already_paid" {
|
||||
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reload order: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
|
||||
now := time.Now()
|
||||
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query expired: %w", err)
|
||||
}
|
||||
n := 0
|
||||
for _, o := range orders {
|
||||
// Check upstream payment status before expiring — the user may have
|
||||
// paid just before timeout and the webhook hasn't arrived yet.
|
||||
outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
|
||||
if outcome == "already_paid" {
|
||||
slog.Info("order was paid during expiry", "orderID", o.ID)
|
||||
continue
|
||||
}
|
||||
if outcome != "" {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// getOrderProvider creates a provider using the order's original instance config.
|
||||
// Falls back to registry lookup if instance ID is missing (legacy orders).
|
||||
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
|
||||
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
|
||||
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
|
||||
if err == nil {
|
||||
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
|
||||
if err == nil {
|
||||
providerKey := s.registry.GetProviderKey(o.PaymentType)
|
||||
if providerKey == "" {
|
||||
providerKey = o.PaymentType
|
||||
}
|
||||
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
|
||||
if err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.EnsureProviders(ctx)
|
||||
return s.registry.GetProvider(o.PaymentType)
|
||||
}
|
||||
73
backend/internal/service/payment_order_expiry_service.go
Normal file
73
backend/internal/service/payment_order_expiry_service.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const expiryCheckTimeout = 30 * time.Second
|
||||
|
||||
// PaymentOrderExpiryService periodically expires timed-out payment orders.
|
||||
type PaymentOrderExpiryService struct {
|
||||
paymentSvc *PaymentService
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewPaymentOrderExpiryService(paymentSvc *PaymentService, interval time.Duration) *PaymentOrderExpiryService {
|
||||
return &PaymentOrderExpiryService{
|
||||
paymentSvc: paymentSvc,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentOrderExpiryService) Start() {
|
||||
if s == nil || s.paymentSvc == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *PaymentOrderExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *PaymentOrderExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
expired, err := s.paymentSvc.ExpireTimedOutOrders(ctx)
|
||||
if err != nil {
|
||||
slog.Error("[PaymentOrderExpiry] failed to expire orders", "error", err)
|
||||
return
|
||||
}
|
||||
if expired > 0 {
|
||||
slog.Info("[PaymentOrderExpiry] expired timed-out orders", "count", expired)
|
||||
}
|
||||
}
|
||||
248
backend/internal/service/payment_refund.go
Normal file
248
backend/internal/service/payment_refund.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// --- Refund Flow ---
|
||||
|
||||
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
|
||||
o, err := s.validateRefundRequest(ctx, oid, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u, err := s.userRepo.GetByID(ctx, o.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if u.Balance < o.Amount {
|
||||
return infraerrors.BadRequest("BALANCE_NOT_ENOUGH", "refund amount exceeds balance")
|
||||
}
|
||||
nr := strings.TrimSpace(reason)
|
||||
now := time.Now()
|
||||
by := fmt.Sprintf("%d", uid)
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(oid), paymentorder.UserIDEQ(uid), paymentorder.StatusEQ(OrderStatusCompleted), paymentorder.OrderTypeEQ(payment.OrderTypeBalance)).SetStatus(OrderStatusRefundRequested).SetRefundRequestedAt(now).SetRefundRequestReason(nr).SetRefundRequestedBy(by).SetRefundAmount(o.Amount).Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update: %w", err)
|
||||
}
|
||||
if c == 0 {
|
||||
return infraerrors.Conflict("CONFLICT", "order status changed")
|
||||
}
|
||||
s.writeAuditLog(ctx, oid, "REFUND_REQUESTED", fmt.Sprintf("user:%d", uid), map[string]any{"amount": o.Amount, "reason": nr})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int64) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.UserID != uid {
|
||||
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission")
|
||||
}
|
||||
if o.OrderType != payment.OrderTypeBalance {
|
||||
return nil, infraerrors.BadRequest("INVALID_ORDER_TYPE", "only balance orders can request refund")
|
||||
}
|
||||
if o.Status != OrderStatusCompleted {
|
||||
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float64, reason string, force, deduct bool) (*RefundPlan, *RefundResult, error) {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
ok := []string{OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed}
|
||||
if !psSliceContains(ok, o.Status) {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
|
||||
}
|
||||
if math.IsNaN(amt) || math.IsInf(amt, 0) {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
|
||||
}
|
||||
if amt <= 0 {
|
||||
amt = o.Amount
|
||||
}
|
||||
if amt-o.Amount > amountToleranceCNY {
|
||||
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
|
||||
}
|
||||
// Full refund: use actual pay_amount for gateway (includes fees)
|
||||
ga := amt
|
||||
if math.Abs(amt-o.Amount) <= amountToleranceCNY {
|
||||
ga = o.PayAmount
|
||||
}
|
||||
rr := strings.TrimSpace(reason)
|
||||
if rr == "" && o.RefundRequestReason != nil {
|
||||
rr = *o.RefundRequestReason
|
||||
}
|
||||
if rr == "" {
|
||||
rr = fmt.Sprintf("refund order:%d", o.ID)
|
||||
}
|
||||
p := &RefundPlan{OrderID: oid, Order: o, RefundAmount: amt, GatewayAmount: ga, Reason: rr, Force: force, DeductBalance: deduct, DeductionType: payment.DeductionTypeNone}
|
||||
if deduct {
|
||||
if er := s.prepDeduct(ctx, o, p, force); er != nil {
|
||||
return nil, er, nil
|
||||
}
|
||||
}
|
||||
return p, nil, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
|
||||
if o.OrderType == payment.OrderTypeSubscription {
|
||||
p.DeductionType = payment.DeductionTypeSubscription
|
||||
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
|
||||
p.SubDaysToDeduct = *o.SubscriptionDays
|
||||
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
|
||||
if err == nil && sub != nil {
|
||||
p.SubscriptionID = sub.ID
|
||||
} else if !force {
|
||||
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
u, err := s.userRepo.GetByID(ctx, o.UserID)
|
||||
if err != nil {
|
||||
if !force {
|
||||
return &RefundResult{Success: false, Warning: "cannot fetch user balance, use force", RequireForce: true}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
p.DeductionType = payment.DeductionTypeBalance
|
||||
p.BalanceToDeduct = math.Min(p.RefundAmount, u.Balance)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
|
||||
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(p.OrderID), paymentorder.StatusIn(OrderStatusCompleted, OrderStatusRefundRequested, OrderStatusRefundFailed)).SetStatus(OrderStatusRefunding).Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lock: %w", err)
|
||||
}
|
||||
if c == 0 {
|
||||
return nil, infraerrors.Conflict("CONFLICT", "order status changed")
|
||||
}
|
||||
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
|
||||
// Skip balance deduction on retry if previous attempt already deducted
|
||||
// but failed to roll back (REFUND_ROLLBACK_FAILED in audit log).
|
||||
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
|
||||
if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
|
||||
s.restoreStatus(ctx, p)
|
||||
return nil, fmt.Errorf("deduction: %w", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
|
||||
p.BalanceToDeduct = 0
|
||||
}
|
||||
}
|
||||
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
|
||||
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
|
||||
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
|
||||
if err != nil {
|
||||
// If deducting would expire the subscription, revoke it entirely
|
||||
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
|
||||
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
|
||||
s.restoreStatus(ctx, p)
|
||||
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
|
||||
p.SubDaysToDeduct = 0
|
||||
}
|
||||
}
|
||||
if err := s.gwRefund(ctx, p); err != nil {
|
||||
return s.handleGwFail(ctx, p, err)
|
||||
}
|
||||
return s.markRefundOk(ctx, p)
|
||||
}
|
||||
|
||||
func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
|
||||
if p.Order.PaymentTradeNo == "" {
|
||||
s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use the exact provider instance that created this order, not a random one
|
||||
// from the registry. Each instance has its own merchant credentials.
|
||||
prov, err := s.getRefundProvider(ctx, p.Order)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get refund provider: %w", err)
|
||||
}
|
||||
_, err = prov.Refund(ctx, payment.RefundRequest{
|
||||
TradeNo: p.Order.PaymentTradeNo,
|
||||
OrderID: p.Order.OutTradeNo,
|
||||
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
|
||||
Reason: p.Reason,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// getRefundProvider creates a provider using the order's original instance config.
|
||||
// Delegates to getOrderProvider which handles instance lookup and fallback.
|
||||
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
|
||||
return s.getOrderProvider(ctx, o)
|
||||
}
|
||||
|
||||
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
|
||||
if s.RollbackRefund(ctx, p, gErr) {
|
||||
s.restoreStatus(ctx, p)
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_GATEWAY_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
|
||||
return &RefundResult{Success: false, Warning: "gateway failed: " + psErrMsg(gErr) + ", rolled back"}, nil
|
||||
}
|
||||
now := time.Now()
|
||||
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(OrderStatusRefundFailed).SetFailedAt(now).SetFailedReason(psErrMsg(gErr)).Save(ctx)
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_FAILED", "admin", map[string]any{"detail": psErrMsg(gErr)})
|
||||
return nil, infraerrors.InternalServer("REFUND_FAILED", psErrMsg(gErr))
|
||||
}
|
||||
|
||||
func (s *PaymentService) markRefundOk(ctx context.Context, p *RefundPlan) (*RefundResult, error) {
|
||||
fs := OrderStatusRefunded
|
||||
if p.RefundAmount < p.Order.Amount {
|
||||
fs = OrderStatusPartiallyRefunded
|
||||
}
|
||||
now := time.Now()
|
||||
_, err := s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(fs).SetRefundAmount(p.RefundAmount).SetRefundReason(p.Reason).SetRefundAt(now).SetForceRefund(p.Force).Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mark refund: %w", err)
|
||||
}
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_SUCCESS", "admin", map[string]any{"refundAmount": p.RefundAmount, "reason": p.Reason, "balanceDeducted": p.BalanceToDeduct, "force": p.Force})
|
||||
return &RefundResult{Success: true, BalanceDeducted: p.BalanceToDeduct, SubDaysDeducted: p.SubDaysToDeduct}, nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr error) bool {
|
||||
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
|
||||
if err := s.userRepo.UpdateBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
|
||||
slog.Error("[CRITICAL] rollback failed", "orderID", p.OrderID, "amount", p.BalanceToDeduct, "error", err)
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "balanceDeducted": p.BalanceToDeduct})
|
||||
return false
|
||||
}
|
||||
}
|
||||
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
|
||||
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
|
||||
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *PaymentService) restoreStatus(ctx context.Context, p *RefundPlan) {
|
||||
rs := OrderStatusCompleted
|
||||
if p.Order.Status == OrderStatusRefundRequested {
|
||||
rs = OrderStatusRefundRequested
|
||||
}
|
||||
_, _ = s.entClient.PaymentOrder.UpdateOneID(p.OrderID).SetStatus(rs).Save(ctx)
|
||||
}
|
||||
305
backend/internal/service/payment_service.go
Normal file
305
backend/internal/service/payment_service.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
|
||||
)
|
||||
|
||||
// --- Order Status Constants ---
|
||||
|
||||
const (
|
||||
OrderStatusPending = payment.OrderStatusPending
|
||||
OrderStatusPaid = payment.OrderStatusPaid
|
||||
OrderStatusRecharging = payment.OrderStatusRecharging
|
||||
OrderStatusCompleted = payment.OrderStatusCompleted
|
||||
OrderStatusExpired = payment.OrderStatusExpired
|
||||
OrderStatusCancelled = payment.OrderStatusCancelled
|
||||
OrderStatusFailed = payment.OrderStatusFailed
|
||||
OrderStatusRefundRequested = payment.OrderStatusRefundRequested
|
||||
OrderStatusRefunding = payment.OrderStatusRefunding
|
||||
OrderStatusPartiallyRefunded = payment.OrderStatusPartiallyRefunded
|
||||
OrderStatusRefunded = payment.OrderStatusRefunded
|
||||
OrderStatusRefundFailed = payment.OrderStatusRefundFailed
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultMaxPendingOrders and defaultOrderTimeoutMin are defined in
|
||||
// payment_config_service.go alongside other payment configuration defaults.
|
||||
paymentGraceMinutes = 5
|
||||
|
||||
defaultPageSize = 20
|
||||
maxPageSize = 100
|
||||
topUsersLimit = 10
|
||||
amountToleranceCNY = 0.01
|
||||
|
||||
orderIDPrefix = "sub2_"
|
||||
)
|
||||
|
||||
// --- Types ---
|
||||
|
||||
// generateOutTradeNo creates a unique external order ID for payment providers.
|
||||
// Format: sub2_20250409aB3kX9mQ (prefix + date + 8-char random)
|
||||
func generateOutTradeNo() string {
|
||||
date := time.Now().Format("20060102")
|
||||
rnd := generateRandomString(8)
|
||||
return orderIDPrefix + date + rnd
|
||||
}
|
||||
|
||||
func generateRandomString(n int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.IntN(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
type CreateOrderRequest struct {
|
||||
UserID int64
|
||||
Amount float64
|
||||
PaymentType string
|
||||
ClientIP string
|
||||
IsMobile bool
|
||||
SrcHost string
|
||||
SrcURL string
|
||||
OrderType string
|
||||
PlanID int64
|
||||
}
|
||||
|
||||
type CreateOrderResponse struct {
|
||||
OrderID int64 `json:"order_id"`
|
||||
Amount float64 `json:"amount"`
|
||||
PayAmount float64 `json:"pay_amount"`
|
||||
FeeRate float64 `json:"fee_rate"`
|
||||
Status string `json:"status"`
|
||||
PaymentType string `json:"payment_type"`
|
||||
PayURL string `json:"pay_url,omitempty"`
|
||||
QRCode string `json:"qr_code,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
PaymentMode string `json:"payment_mode,omitempty"`
|
||||
}
|
||||
|
||||
type OrderListParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
Status string
|
||||
OrderType string
|
||||
PaymentType string
|
||||
Keyword string
|
||||
}
|
||||
|
||||
type RefundPlan struct {
|
||||
OrderID int64
|
||||
Order *dbent.PaymentOrder
|
||||
RefundAmount float64
|
||||
GatewayAmount float64
|
||||
Reason string
|
||||
Force bool
|
||||
DeductBalance bool
|
||||
DeductionType string
|
||||
BalanceToDeduct float64
|
||||
SubDaysToDeduct int
|
||||
SubscriptionID int64
|
||||
}
|
||||
|
||||
type RefundResult struct {
|
||||
Success bool `json:"success"`
|
||||
Warning string `json:"warning,omitempty"`
|
||||
RequireForce bool `json:"require_force,omitempty"`
|
||||
BalanceDeducted float64 `json:"balance_deducted,omitempty"`
|
||||
SubDaysDeducted int `json:"subscription_days_deducted,omitempty"`
|
||||
}
|
||||
|
||||
type DashboardStats struct {
|
||||
TodayAmount float64 `json:"today_amount"`
|
||||
TotalAmount float64 `json:"total_amount"`
|
||||
TodayCount int `json:"today_count"`
|
||||
TotalCount int `json:"total_count"`
|
||||
AvgAmount float64 `json:"avg_amount"`
|
||||
PendingOrders int `json:"pending_orders"`
|
||||
|
||||
DailySeries []DailyStats `json:"daily_series"`
|
||||
PaymentMethods []PaymentMethodStat `json:"payment_methods"`
|
||||
TopUsers []TopUserStat `json:"top_users"`
|
||||
}
|
||||
|
||||
type DailyStats struct {
|
||||
Date string `json:"date"`
|
||||
Amount float64 `json:"amount"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type PaymentMethodStat struct {
|
||||
Type string `json:"type"`
|
||||
Amount float64 `json:"amount"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type TopUserStat struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Amount float64 `json:"amount"`
|
||||
}
|
||||
|
||||
// --- Service ---
|
||||
|
||||
type PaymentService struct {
|
||||
providerMu sync.Mutex
|
||||
providersLoaded bool
|
||||
entClient *dbent.Client
|
||||
registry *payment.Registry
|
||||
loadBalancer payment.LoadBalancer
|
||||
redeemService *RedeemService
|
||||
subscriptionSvc *SubscriptionService
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
||||
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||
}
|
||||
|
||||
// --- Provider Registry ---
|
||||
|
||||
// EnsureProviders lazily initializes the provider registry on first call.
|
||||
func (s *PaymentService) EnsureProviders(ctx context.Context) {
|
||||
s.providerMu.Lock()
|
||||
defer s.providerMu.Unlock()
|
||||
if !s.providersLoaded {
|
||||
s.loadProviders(ctx)
|
||||
s.providersLoaded = true
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshProviders clears and re-registers all providers from the database.
|
||||
func (s *PaymentService) RefreshProviders(ctx context.Context) {
|
||||
s.providerMu.Lock()
|
||||
defer s.providerMu.Unlock()
|
||||
s.registry.Clear()
|
||||
s.loadProviders(ctx)
|
||||
s.providersLoaded = true
|
||||
}
|
||||
|
||||
func (s *PaymentService) loadProviders(ctx context.Context) {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
slog.Error("[PaymentService] failed to query provider instances", "error", err)
|
||||
return
|
||||
}
|
||||
for _, inst := range instances {
|
||||
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
|
||||
if err != nil {
|
||||
slog.Warn("[PaymentService] failed to decrypt config for instance", "instanceID", inst.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
if inst.PaymentMode != "" {
|
||||
cfg["paymentMode"] = inst.PaymentMode
|
||||
}
|
||||
instID := fmt.Sprintf("%d", inst.ID)
|
||||
p, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
|
||||
if err != nil {
|
||||
slog.Warn("[PaymentService] failed to create provider for instance", "instanceID", inst.ID, "key", inst.ProviderKey, "error", err)
|
||||
continue
|
||||
}
|
||||
s.registry.Register(p)
|
||||
}
|
||||
}
|
||||
|
||||
// GetWebhookProvider returns the provider instance that should verify a webhook.
|
||||
// It extracts out_trade_no from the raw body, looks up the order to find the
|
||||
// original provider instance, and creates a provider with that instance's credentials.
|
||||
// Falls back to the registry provider when the order cannot be found.
|
||||
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
|
||||
if outTradeNo != "" {
|
||||
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
|
||||
if err == nil {
|
||||
p, pErr := s.getOrderProvider(ctx, order)
|
||||
if pErr == nil {
|
||||
return p, nil
|
||||
}
|
||||
slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
|
||||
}
|
||||
}
|
||||
s.EnsureProviders(ctx)
|
||||
return s.registry.GetProviderByKey(providerKey)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func psIsRefundStatus(s string) bool {
|
||||
switch s {
|
||||
case OrderStatusRefundRequested, OrderStatusRefunding, OrderStatusPartiallyRefunded, OrderStatusRefunded, OrderStatusRefundFailed:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func psErrMsg(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func psNilIfEmpty(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func psSliceContains(sl []string, s string) bool {
|
||||
for _, v := range sl {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func psComputeValidityDays(days int, unit string) int {
|
||||
switch unit {
|
||||
case "week":
|
||||
return days * 7
|
||||
case "month":
|
||||
return days * 30
|
||||
default:
|
||||
return days
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentService) getFeeRate(_ string) float64 { return 0 }
|
||||
|
||||
func psStartOfDayUTC(t time.Time) time.Time {
|
||||
y, m, d := t.UTC().Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func applyPagination(pageSize, page int) (size, pg int) {
|
||||
size = pageSize
|
||||
if size <= 0 {
|
||||
size = defaultPageSize
|
||||
}
|
||||
if size > maxPageSize {
|
||||
size = maxPageSize
|
||||
}
|
||||
pg = page
|
||||
if pg < 1 {
|
||||
pg = 1
|
||||
}
|
||||
return size, pg
|
||||
}
|
||||
163
backend/internal/service/payment_stats.go
Normal file
163
backend/internal/service/payment_stats.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
|
||||
)
|
||||
|
||||
// --- Dashboard & Analytics ---
|
||||
|
||||
func (s *PaymentService) GetDashboardStats(ctx context.Context, days int) (*DashboardStats, error) {
|
||||
if days <= 0 {
|
||||
days = 30
|
||||
}
|
||||
now := time.Now()
|
||||
since := now.AddDate(0, 0, -days)
|
||||
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
|
||||
paidStatuses := []string{OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging}
|
||||
|
||||
orders, err := s.entClient.PaymentOrder.Query().
|
||||
Where(
|
||||
paymentorder.StatusIn(paidStatuses...),
|
||||
paymentorder.PaidAtGTE(since),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
st := &DashboardStats{}
|
||||
computeBasicStats(st, orders, todayStart)
|
||||
|
||||
st.PendingOrders, err = s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.StatusEQ(OrderStatusPending)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
st.DailySeries = buildDailySeries(orders, since, days)
|
||||
st.PaymentMethods = buildMethodDistribution(orders)
|
||||
st.TopUsers = buildTopUsers(orders)
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func computeBasicStats(st *DashboardStats, orders []*dbent.PaymentOrder, todayStart time.Time) {
|
||||
var totalAmount, todayAmount float64
|
||||
var todayCount int
|
||||
for _, o := range orders {
|
||||
totalAmount += o.PayAmount
|
||||
if o.PaidAt != nil && !o.PaidAt.Before(todayStart) {
|
||||
todayAmount += o.PayAmount
|
||||
todayCount++
|
||||
}
|
||||
}
|
||||
st.TotalAmount = math.Round(totalAmount*100) / 100
|
||||
st.TodayAmount = math.Round(todayAmount*100) / 100
|
||||
st.TotalCount = len(orders)
|
||||
st.TodayCount = todayCount
|
||||
if st.TotalCount > 0 {
|
||||
st.AvgAmount = math.Round(totalAmount/float64(st.TotalCount)*100) / 100
|
||||
}
|
||||
}
|
||||
|
||||
func buildDailySeries(orders []*dbent.PaymentOrder, since time.Time, days int) []DailyStats {
|
||||
dailyMap := make(map[string]*DailyStats)
|
||||
for _, o := range orders {
|
||||
if o.PaidAt == nil {
|
||||
continue
|
||||
}
|
||||
date := o.PaidAt.Format("2006-01-02")
|
||||
ds, ok := dailyMap[date]
|
||||
if !ok {
|
||||
ds = &DailyStats{Date: date}
|
||||
dailyMap[date] = ds
|
||||
}
|
||||
ds.Amount += o.PayAmount
|
||||
ds.Count++
|
||||
}
|
||||
series := make([]DailyStats, 0, days)
|
||||
for i := 0; i < days; i++ {
|
||||
date := since.AddDate(0, 0, i+1).Format("2006-01-02")
|
||||
if ds, ok := dailyMap[date]; ok {
|
||||
ds.Amount = math.Round(ds.Amount*100) / 100
|
||||
series = append(series, *ds)
|
||||
} else {
|
||||
series = append(series, DailyStats{Date: date})
|
||||
}
|
||||
}
|
||||
return series
|
||||
}
|
||||
|
||||
func buildMethodDistribution(orders []*dbent.PaymentOrder) []PaymentMethodStat {
|
||||
methodMap := make(map[string]*PaymentMethodStat)
|
||||
for _, o := range orders {
|
||||
ms, ok := methodMap[o.PaymentType]
|
||||
if !ok {
|
||||
ms = &PaymentMethodStat{Type: o.PaymentType}
|
||||
methodMap[o.PaymentType] = ms
|
||||
}
|
||||
ms.Amount += o.PayAmount
|
||||
ms.Count++
|
||||
}
|
||||
methods := make([]PaymentMethodStat, 0, len(methodMap))
|
||||
for _, ms := range methodMap {
|
||||
ms.Amount = math.Round(ms.Amount*100) / 100
|
||||
methods = append(methods, *ms)
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
func buildTopUsers(orders []*dbent.PaymentOrder) []TopUserStat {
|
||||
userMap := make(map[int64]*TopUserStat)
|
||||
for _, o := range orders {
|
||||
us, ok := userMap[o.UserID]
|
||||
if !ok {
|
||||
us = &TopUserStat{UserID: o.UserID, Email: o.UserEmail}
|
||||
userMap[o.UserID] = us
|
||||
}
|
||||
us.Amount += o.PayAmount
|
||||
}
|
||||
userList := make([]*TopUserStat, 0, len(userMap))
|
||||
for _, us := range userMap {
|
||||
us.Amount = math.Round(us.Amount*100) / 100
|
||||
userList = append(userList, us)
|
||||
}
|
||||
sort.Slice(userList, func(i, j int) bool {
|
||||
return userList[i].Amount > userList[j].Amount
|
||||
})
|
||||
limit := topUsersLimit
|
||||
if len(userList) < limit {
|
||||
limit = len(userList)
|
||||
}
|
||||
result := make([]TopUserStat, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
result = append(result, *userList[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Audit Logs ---
|
||||
|
||||
func (s *PaymentService) writeAuditLog(ctx context.Context, oid int64, action, op string, detail map[string]any) {
|
||||
dj, _ := json.Marshal(detail)
|
||||
_, err := s.entClient.PaymentAuditLog.Create().SetOrderID(strconv.FormatInt(oid, 10)).SetAction(action).SetDetail(string(dj)).SetOperator(op).Save(ctx)
|
||||
if err != nil {
|
||||
slog.Error("audit log failed", "orderID", oid, "action", action, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PaymentService) GetOrderAuditLogs(ctx context.Context, oid int64) ([]*dbent.PaymentAuditLog, error) {
|
||||
return s.entClient.PaymentAuditLog.Query().Where(paymentauditlog.OrderIDEQ(strconv.FormatInt(oid, 10))).Order(paymentauditlog.ByCreatedAt()).All(ctx)
|
||||
}
|
||||
@@ -167,6 +167,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyBackendModeEnabled,
|
||||
SettingKeyOIDCConnectEnabled,
|
||||
SettingKeyOIDCConnectProviderName,
|
||||
SettingPaymentEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -227,6 +228,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
OIDCOAuthEnabled: oidcEnabled,
|
||||
OIDCOAuthProviderName: oidcProviderName,
|
||||
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -276,6 +278,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
@@ -303,6 +306,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -139,6 +139,7 @@ type PublicSettings struct {
|
||||
BackendModeEnabled bool
|
||||
OIDCOAuthEnabled bool
|
||||
OIDCOAuthProviderName string
|
||||
PaymentEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -460,4 +462,20 @@ var ProviderSet = wire.NewSet(
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
ProvidePaymentConfigService,
|
||||
NewPaymentService,
|
||||
ProvidePaymentOrderExpiryService,
|
||||
)
|
||||
|
||||
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
|
||||
// payment.EncryptionKey type instead of raw []byte, avoiding Wire ambiguity.
|
||||
func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRepository, key payment.EncryptionKey) *PaymentConfigService {
|
||||
return NewPaymentConfigService(entClient, settingRepo, []byte(key))
|
||||
}
|
||||
|
||||
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
|
||||
func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderExpiryService {
|
||||
svc := NewPaymentOrderExpiryService(paymentSvc, 60*time.Second)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user