diff --git a/backend/ent/client.go b/backend/ent/client.go index 3da7acf8..e52e015a 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 77d3e16e..8d8320bb 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -336,7 +336,6 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q) } - // The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 1fff61ba..68bdbf55 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -616,6 +616,7 @@ var ( {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "refund_enabled", Type: field.TypeBool, Default: false}, + {Name: "allow_user_refund", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 3bca248d..524ccb92 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error { // PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. type PaymentProviderInstanceMutation struct { config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance } var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) @@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { m.refund_enabled = nil } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return + } + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + } + return oldValue.AllowUserRefund, nil +} + +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil +} + // SetCreatedAt sets the "created_at" field. func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.provider_key != nil { fields = append(fields, paymentproviderinstance.FieldProviderKey) } @@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { if m.refund_enabled != nil { fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + } if m.created_at != nil { fields = append(fields, paymentproviderinstance.FieldCreatedAt) } @@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { return m.Limits() case paymentproviderinstance.FieldRefundEnabled: return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() case paymentproviderinstance.FieldCreatedAt: return m.CreatedAt() case paymentproviderinstance.FieldUpdatedAt: @@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str return m.OldLimits(ctx) case paymentproviderinstance.FieldRefundEnabled: return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) case paymentproviderinstance.FieldCreatedAt: return m.OldCreatedAt(ctx) case paymentproviderinstance.FieldUpdatedAt: @@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) } m.SetRefundEnabled(v) return nil + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) + return nil case paymentproviderinstance.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error { case paymentproviderinstance.FieldRefundEnabled: m.ResetRefundEnabled() return nil + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() + return nil case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go index 087cb13a..4279b86e 100644 --- a/backend/ent/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance.go @@ -35,6 +35,8 @@ type PaymentProviderInstance struct { Limits string `json:"limits,omitempty"` // RefundEnabled holds the value of the "refund_enabled" field. RefundEnabled bool `json:"refund_enabled,omitempty"` + // AllowUserRefund holds the value of the "allow_user_refund" field. + AllowUserRefund bool `json:"allow_user_refund,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: + case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund: values[i] = new(sql.NullBool) case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: values[i] = new(sql.NullInt64) @@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) } else if value.Valid { _m.RefundEnabled = value.Bool } + case paymentproviderinstance.FieldAllowUserRefund: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i]) + } else if value.Valid { + _m.AllowUserRefund = value.Bool + } case paymentproviderinstance.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string { builder.WriteString("refund_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(", ") + builder.WriteString("allow_user_refund=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go index c430fef6..eb1b0c52 100644 --- a/backend/ent/paymentproviderinstance/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go @@ -31,6 +31,8 @@ const ( FieldLimits = "limits" // FieldRefundEnabled holds the string denoting the refund_enabled field in the database. FieldRefundEnabled = "refund_enabled" + // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database. + FieldAllowUserRefund = "allow_user_refund" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -51,6 +53,7 @@ var Columns = []string{ FieldSortOrder, FieldLimits, FieldRefundEnabled, + FieldAllowUserRefund, FieldCreatedAt, FieldUpdatedAt, } @@ -88,6 +91,8 @@ var ( DefaultLimits string // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. DefaultRefundEnabled bool + // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field. + DefaultAllowUserRefund bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() } +// ByAllowUserRefund orders the results by the allow_user_refund field. +func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go index 7b99517f..40e5a1f6 100644 --- a/backend/ent/paymentproviderinstance/where.go +++ b/backend/ent/paymentproviderinstance/where.go @@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) } +// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ. +func AllowUserRefund(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) @@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) } +// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field. +func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + +// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field. +func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go index 20b16ddd..d1b14617 100644 --- a/backend/ent/paymentproviderinstance_create.go +++ b/backend/ent/paymentproviderinstance_create.go @@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym return _c } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate { + _c.mutation.SetAllowUserRefund(v) + return _c +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate { + if v != nil { + _c.SetAllowUserRefund(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { _c.mutation.SetCreatedAt(v) @@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() { v := paymentproviderinstance.DefaultRefundEnabled _c.mutation.SetRefundEnabled(v) } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + v := paymentproviderinstance.DefaultAllowUserRefund + _c.mutation.SetAllowUserRefund(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := paymentproviderinstance.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error { if _, ok := _c.mutation.RefundEnabled(); !ok { return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} } @@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _node.RefundEnabled = value } + if value, ok := _c.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + _node.AllowUserRefund = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn return u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert { + u.Set(paymentproviderinstance.FieldAllowUserRefund, v) + return u +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert { + u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund) + return u +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { u.Set(paymentproviderinstance.FieldUpdatedAt, v) @@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { return u.Update(func(s *PaymentProviderInstanceUpsert) { @@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { return u.Update(func(s *PaymentProviderInstanceUpsert) { diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go index 06dba527..6bb3a82d 100644 --- a/backend/ent/paymentproviderinstance_update.go +++ b/backend/ent/paymentproviderinstance_update.go @@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { _u.mutation.SetUpdatedAt(v) @@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } @@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { _u.mutation.SetUpdatedAt(v) @@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 67f37c75..ef551940 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,7 +33,6 @@ type IdempotencyRecord func(*sql.Selector) // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) - // PaymentOrder is the predicate function for paymentorder builders. type PaymentOrder func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 951b5f99..fbdd08c7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -668,12 +668,16 @@ func init() { paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) + // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field. + paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor() + // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field. + paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool) // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. - paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() + paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor() // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. - paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() + paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor() // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go index 08ab7d31..e4c0b72c 100644 --- a/backend/ent/schema/payment_provider_instance.go +++ b/backend/ent/schema/payment_provider_instance.go @@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field { Default(""), field.Bool("refund_enabled"). Default(false), + field.Bool("allow_user_refund"). + Default(false), field.Time("created_at"). Immutable(). Default(time.Now). diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 0425fc49..5fde86fa 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { response.Success(c, gin.H{"message": "refund requested"}) } +// GetRefundEligibleProviders returns provider instance IDs that allow user refund. +func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) { + ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"provider_instance_ids": ids}) +} + // VerifyOrderRequest is the request body for verifying a payment order. type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 72012a4e..8def7559 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -37,6 +37,7 @@ func RegisterPaymentRoutes( orders.GET("/:id", paymentHandler.GetOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/refund-request", paymentHandler.RequestRefund) + orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders) } } diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 47b7496f..90ff450f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -2,7 +2,6 @@ package service import ( "context" - "sort" "strings" ) @@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int return false } -// wildcardMatch 通配符匹配候选项(用于排序) -type wildcardMatch struct { - prefixLen int - pricing *ChannelModelPricing -} - // findPricingForModel 在定价列表中查找匹配的模型定价。 -// 先精确匹配,再通配符匹配(前缀越长优先级越高)。 +// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。 func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { // 精确匹配优先 for i := range pricingList { @@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } } } - // 通配符匹配:收集所有匹配项,按前缀长度降序取最长 - var matches []wildcardMatch + // 通配符匹配:按配置顺序,先匹配先使用 for i := range pricingList { p := &pricingList[i] if !isPlatformMatch(platform, p.Platform) { @@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } prefix := strings.TrimSuffix(ml, "*") if strings.HasPrefix(modelLower, prefix) { - matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p}) + return p } } } - if len(matches) == 0 { - return nil - } - sort.Slice(matches, func(i, j int) bool { - return matches[i].prefixLen > matches[j].prefixLen - }) - return matches[0].pricing + return nil } // isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 2f625393..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "wildcard matches by longest prefix (most specific wins)", + name: "wildcard matches by config order (first match wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" + wantID: 10, // config order: "claude-*" is first and matches, so it wins }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 47008df0..0f7cb99a 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. @@ -47,7 +48,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, - SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, + AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) if err != nil { @@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } + allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) + // Cascade: turning off refund_enabled also disables allow_user_refund + if !*req.RefundEnabled { + u.SetAllowUserRefund(false) + } + } + if req.AllowUserRefund != nil { + // Only allow enabling when refund_enabled is true + if *req.AllowUserRefund { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil && inst.RefundEnabled { + u.SetAllowUserRefund(true) + } + } else { + u.SetAllowUserRefund(false) + } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.RefundEnabledEQ(true), + paymentproviderinstance.AllowUserRefundEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { return nil, err diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..cce31f4d 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,28 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` + AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 68f9c697..75d75b2f 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -17,6 +17,19 @@ import ( // --- Refund Flow --- +// getOrderProviderInstance looks up the provider instance that processed this order. +// Returns nil, nil for legacy orders without provider_instance_id. +func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + return nil, nil + } + instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + if err != nil { + return nil, nil + } + return s.entClient.PaymentProviderInstance.Get(ctx, instID) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { @@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int if o.Status != OrderStatusCompleted { return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } + // Check provider instance allows user refund + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil || inst == nil { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order") + } + if !inst.AllowUserRefund { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider") + } return o, nil } @@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float if !psSliceContains(ok, o.Status) { return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } + // Check provider instance allows admin refund + inst, instErr := s.getOrderProviderInstance(ctx, o) + if instErr != nil { + slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr) + } + if inst != nil && !inst.RefundEnabled { + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") + } + if inst == nil && instErr == nil { + // Legacy order without provider_instance_id — block refund + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order") + } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } @@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult { if o.OrderType == payment.OrderTypeSubscription { p.DeductionType = payment.DeductionTypeSubscription + if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil { + p.SubDaysToDeduct = *o.SubscriptionDays + sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID) + if err == nil && sub != nil { + p.SubscriptionID = sub.ID + } else if !force { + return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true} + } + } return nil } u, err := s.userRepo.GetByID(ctx, o.UserID) @@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref p.BalanceToDeduct = 0 } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { + _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) + if err != nil { + slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) + if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + } + } + } else { + slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID) + p.SubDaysToDeduct = 0 + } + } if err := s.gwRefund(ctx, p); err != nil { return s.handleGwFail(ctx, p, err) } @@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr return false } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil { + slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err) + s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct}) + return false + } + } return true } diff --git a/backend/migrations/103_add_allow_user_refund.sql b/backend/migrations/103_add_allow_user_refund.sql new file mode 100644 index 00000000..79525382 --- /dev/null +++ b/backend/migrations/103_add_allow_user_refund.sql @@ -0,0 +1 @@ +ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS allow_user_refund BOOLEAN NOT NULL DEFAULT false; diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 1389b60f..5cedb107 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -75,5 +75,10 @@ export const paymentAPI = { /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) + }, + + /** Get provider instance IDs that allow user refund */ + getRefundEligibleProviders() { + return apiClient.get<{ provider_instance_ids: string[] }>('/payment/orders/refund-eligible-providers') } } diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue index 9b60cba1..10c1bfea 100644 --- a/frontend/src/components/payment/PaymentProviderDialog.vue +++ b/frontend/src/components/payment/PaymentProviderDialog.vue @@ -32,7 +32,8 @@