feat: add per-provider allow_user_refund control and align wildcard matching
allow_user_refund: - Add allow_user_refund field to PaymentProviderInstance ent schema - Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN - Cascade logic: disabling refund_enabled auto-disables allow_user_refund - User refund validation: check provider instance allows user refund - Admin refund validation: check provider instance allows admin refund - Subscription refund: deduct days on refund, rollback on failure - New endpoint: GET /payment/orders/refund-eligible-providers - Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView Wildcard matching: - Change findPricingForModel from "longest prefix wins" to "config order priority (first match wins)", aligning with channel service behavior
This commit is contained in:
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int
|
||||
return false
|
||||
}
|
||||
|
||||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||||
type wildcardMatch struct {
|
||||
prefixLen int
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||||
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
|
||||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||
// 精确匹配优先
|
||||
for i := range pricingList {
|
||||
@@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
|
||||
}
|
||||
}
|
||||
}
|
||||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||||
var matches []wildcardMatch
|
||||
// 通配符匹配:按配置顺序,先匹配先使用
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
@@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
|
||||
}
|
||||
prefix := strings.TrimSuffix(ml, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].prefixLen > matches[j].prefixLen
|
||||
})
|
||||
return matches[0].pricing
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||
|
||||
@@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) {
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard matches by longest prefix (most specific wins)",
|
||||
name: "wildcard matches by config order (first match wins)",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*"
|
||||
wantID: 10, // config order: "claude-*" is first and matches, so it wins
|
||||
},
|
||||
{
|
||||
name: "shorter wildcard used when longer does not match",
|
||||
|
||||
@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
|
||||
|
||||
// ProviderInstanceResponse is the API response for a provider instance.
|
||||
type ProviderInstanceResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Limits string `json:"limits"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
ID int64 `json:"id"`
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Limits string `json:"limits"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
AllowUserRefund bool `json:"allow_user_refund"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
}
|
||||
|
||||
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
|
||||
@@ -47,7 +48,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
|
||||
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
|
||||
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
|
||||
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
|
||||
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
|
||||
AllowUserRefund: inst.AllowUserRefund,
|
||||
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
|
||||
}
|
||||
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
|
||||
if err != nil {
|
||||
@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
|
||||
return s.entClient.PaymentProviderInstance.Create().
|
||||
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
|
||||
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
|
||||
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
|
||||
SetAllowUserRefund(allowUserRefund).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
@@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
|
||||
}
|
||||
if req.RefundEnabled != nil {
|
||||
u.SetRefundEnabled(*req.RefundEnabled)
|
||||
// Cascade: turning off refund_enabled also disables allow_user_refund
|
||||
if !*req.RefundEnabled {
|
||||
u.SetAllowUserRefund(false)
|
||||
}
|
||||
}
|
||||
if req.AllowUserRefund != nil {
|
||||
// Only allow enabling when refund_enabled is true
|
||||
if *req.AllowUserRefund {
|
||||
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
|
||||
if err == nil && inst.RefundEnabled {
|
||||
u.SetAllowUserRefund(true)
|
||||
}
|
||||
} else {
|
||||
u.SetAllowUserRefund(false)
|
||||
}
|
||||
}
|
||||
if req.PaymentMode != nil {
|
||||
u.SetPaymentMode(*req.PaymentMode)
|
||||
@@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(
|
||||
paymentproviderinstance.RefundEnabledEQ(true),
|
||||
paymentproviderinstance.AllowUserRefundEQ(true),
|
||||
).Select(paymentproviderinstance.FieldID).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -105,26 +105,28 @@ type MethodLimitsResponse struct {
|
||||
}
|
||||
|
||||
type CreateProviderInstanceRequest struct {
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
Limits string `json:"limits"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
ProviderKey string `json:"provider_key"`
|
||||
Name string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PaymentMode string `json:"payment_mode"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
Limits string `json:"limits"`
|
||||
RefundEnabled bool `json:"refund_enabled"`
|
||||
AllowUserRefund bool `json:"allow_user_refund"`
|
||||
}
|
||||
|
||||
type UpdateProviderInstanceRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
PaymentMode *string `json:"payment_mode"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
Limits *string `json:"limits"`
|
||||
RefundEnabled *bool `json:"refund_enabled"`
|
||||
Name *string `json:"name"`
|
||||
Config map[string]string `json:"config"`
|
||||
SupportedTypes []string `json:"supported_types"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
PaymentMode *string `json:"payment_mode"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
Limits *string `json:"limits"`
|
||||
RefundEnabled *bool `json:"refund_enabled"`
|
||||
AllowUserRefund *bool `json:"allow_user_refund"`
|
||||
}
|
||||
type CreatePlanRequest struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
@@ -17,6 +17,19 @@ import (
|
||||
|
||||
// --- Refund Flow ---
|
||||
|
||||
// getOrderProviderInstance looks up the provider instance that processed this order.
|
||||
// Returns nil, nil for legacy orders without provider_instance_id.
|
||||
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
|
||||
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
|
||||
}
|
||||
|
||||
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
|
||||
o, err := s.validateRefundRequest(ctx, oid, uid)
|
||||
if err != nil {
|
||||
@@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
|
||||
if o.Status != OrderStatusCompleted {
|
||||
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
|
||||
}
|
||||
// Check provider instance allows user refund
|
||||
inst, err := s.getOrderProviderInstance(ctx, o)
|
||||
if err != nil || inst == nil {
|
||||
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
|
||||
}
|
||||
if !inst.AllowUserRefund {
|
||||
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
@@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
|
||||
if !psSliceContains(ok, o.Status) {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
|
||||
}
|
||||
// Check provider instance allows admin refund
|
||||
inst, instErr := s.getOrderProviderInstance(ctx, o)
|
||||
if instErr != nil {
|
||||
slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr)
|
||||
}
|
||||
if inst != nil && !inst.RefundEnabled {
|
||||
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
|
||||
}
|
||||
if inst == nil && instErr == nil {
|
||||
// Legacy order without provider_instance_id — block refund
|
||||
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
|
||||
}
|
||||
if math.IsNaN(amt) || math.IsInf(amt, 0) {
|
||||
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
|
||||
}
|
||||
@@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
|
||||
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
|
||||
if o.OrderType == payment.OrderTypeSubscription {
|
||||
p.DeductionType = payment.DeductionTypeSubscription
|
||||
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
|
||||
p.SubDaysToDeduct = *o.SubscriptionDays
|
||||
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
|
||||
if err == nil && sub != nil {
|
||||
p.SubscriptionID = sub.ID
|
||||
} else if !force {
|
||||
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
u, err := s.userRepo.GetByID(ctx, o.UserID)
|
||||
@@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
|
||||
p.BalanceToDeduct = 0
|
||||
}
|
||||
}
|
||||
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
|
||||
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
|
||||
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
|
||||
if err != nil {
|
||||
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
|
||||
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
|
||||
s.restoreStatus(ctx, p)
|
||||
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
|
||||
p.SubDaysToDeduct = 0
|
||||
}
|
||||
}
|
||||
if err := s.gwRefund(ctx, p); err != nil {
|
||||
return s.handleGwFail(ctx, p, err)
|
||||
}
|
||||
@@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr
|
||||
return false
|
||||
}
|
||||
}
|
||||
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
|
||||
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
|
||||
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
|
||||
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user