feat: add per-provider allow_user_refund control and align wildcard matching

allow_user_refund:
- Add allow_user_refund field to PaymentProviderInstance ent schema
- Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN
- Cascade logic: disabling refund_enabled auto-disables allow_user_refund
- User refund validation: check provider instance allows user refund
- Admin refund validation: check provider instance allows admin refund
- Subscription refund: deduct days on refund, rollback on failure
- New endpoint: GET /payment/orders/refund-eligible-providers
- Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView

Wildcard matching:
- Change findPricingForModel from "longest prefix wins" to "config order
  priority (first match wins)", aligning with channel service behavior
This commit is contained in:
erio
2026-04-14 16:26:46 +08:00
parent e8ee400a3f
commit f1297a3694
28 changed files with 405 additions and 98 deletions

View File

@@ -2,7 +2,6 @@ package service
import (
"context"
"sort"
"strings"
)
@@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int
return false
}
// wildcardMatch 通配符匹配候选项(用于排序)
type wildcardMatch struct {
prefixLen int
pricing *ChannelModelPricing
}
// findPricingForModel 在定价列表中查找匹配的模型定价。
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
// 精确匹配优先
for i := range pricingList {
@@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
}
}
}
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
var matches []wildcardMatch
// 通配符匹配:按配置顺序,先匹配先使用
for i := range pricingList {
p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) {
@@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower
}
prefix := strings.TrimSuffix(ml, "*")
if strings.HasPrefix(modelLower, prefix) {
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
return p
}
}
}
if len(matches) == 0 {
return nil
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].prefixLen > matches[j].prefixLen
})
return matches[0].pricing
return nil
}
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。

View File

@@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) {
wantNil: true,
},
{
name: "wildcard matches by longest prefix (most specific wins)",
name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-opus-4",
wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*"
wantID: 10, // config order: "claude-*" is first and matches, so it wins
},
{
name: "shorter wildcard used when longer does not match",

View File

@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
// ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct {
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
@@ -47,7 +48,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil {
@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil {
return nil, err
}
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx)
}
@@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is true
if *req.AllowUserRefund {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil && inst.RefundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
@@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.RefundEnabledEQ(true),
paymentproviderinstance.AllowUserRefundEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil {
return nil, err

View File

@@ -105,26 +105,28 @@ type MethodLimitsResponse struct {
}
type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
}
type CreatePlanRequest struct {
GroupID int64 `json:"group_id"`

View File

@@ -17,6 +17,19 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
@@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
if !inst.AllowUserRefund {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
}
return o, nil
}
@@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr)
}
if inst != nil && !inst.RefundEnabled {
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
}
if inst == nil && instErr == nil {
// Legacy order without provider_instance_id — block refund
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
}
@@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
if o.OrderType == payment.OrderTypeSubscription {
p.DeductionType = payment.DeductionTypeSubscription
if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
p.SubDaysToDeduct = *o.SubscriptionDays
sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
if err == nil && sub != nil {
p.SubscriptionID = sub.ID
} else if !force {
return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
}
}
return nil
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
@@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
p.BalanceToDeduct = 0
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil {
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
}
} else {
slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
p.SubDaysToDeduct = 0
}
}
if err := s.gwRefund(ctx, p); err != nil {
return s.handleGwFail(ctx, p, err)
}
@@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr
return false
}
}
if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
return false
}
}
return true
}