diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go new file mode 100644 index 00000000..5ccfcf19 --- /dev/null +++ b/backend/ent/authidentity.go @@ -0,0 +1,266 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentity is the model entity for the AuthIdentity schema. +type AuthIdentity struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // VerifiedAt holds the value of the "verified_at" field. + VerifiedAt *time.Time `json:"verified_at,omitempty"` + // Issuer holds the value of the "issuer" field. + Issuer *string `json:"issuer,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityQuery when eager-loading is set. + Edges AuthIdentityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Channels holds the value of the channels edge. + Channels []*AuthIdentityChannel `json:"channels,omitempty"` + // AdoptionDecisions holds the value of the adoption_decisions edge. + AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// ChannelsOrErr returns the Channels value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) { + if e.loadedTypes[1] { + return e.Channels, nil + } + return nil, &NotLoadedError{edge: "channels"} +} + +// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) { + if e.loadedTypes[2] { + return e.AdoptionDecisions, nil + } + return nil, &NotLoadedError{edge: "adoption_decisions"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentity.FieldMetadata: + values[i] = new([]byte) + case authidentity.FieldID, authidentity.FieldUserID: + values[i] = new(sql.NullInt64) + case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer: + values[i] = new(sql.NullString) + case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentity fields. +func (_m *AuthIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentity.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case authidentity.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentity.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentity.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case authidentity.FieldVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field verified_at", values[i]) + } else if value.Valid { + _m.VerifiedAt = new(time.Time) + *_m.VerifiedAt = value.Time + } + case authidentity.FieldIssuer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field issuer", values[i]) + } else if value.Valid { + _m.Issuer = new(string) + *_m.Issuer = value.String + } + case authidentity.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryUser() *UserQuery { + return NewAuthIdentityClient(_m.config).QueryUser(_m) +} + +// QueryChannels queries the "channels" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery { + return NewAuthIdentityClient(_m.config).QueryChannels(_m) +} + +// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m) +} + +// Update returns a builder for updating this AuthIdentity. +// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne { + return NewAuthIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentity) Unwrap() *AuthIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentity) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.VerifiedAt; v != nil { + builder.WriteString("verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Issuer; v != nil { + builder.WriteString("issuer=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentities is a parsable slice of AuthIdentity. +type AuthIdentities []*AuthIdentity diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go new file mode 100644 index 00000000..c90be759 --- /dev/null +++ b/backend/ent/authidentity/authidentity.go @@ -0,0 +1,209 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentity type in the database. + Label = "auth_identity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldVerifiedAt holds the string denoting the verified_at field in the database. + FieldVerifiedAt = "verified_at" + // FieldIssuer holds the string denoting the issuer field in the database. + FieldIssuer = "issuer" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeChannels holds the string denoting the channels edge name in mutations. + EdgeChannels = "channels" + // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations. + EdgeAdoptionDecisions = "adoption_decisions" + // Table holds the table name of the authidentity in the database. + Table = "auth_identities" + // UserTable is the table that holds the user relation/edge. + UserTable = "auth_identities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // ChannelsTable is the table that holds the channels relation/edge. + ChannelsTable = "auth_identity_channels" + // ChannelsInverseTable is the table name for the AuthIdentityChannel entity. + // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package. + ChannelsInverseTable = "auth_identity_channels" + // ChannelsColumn is the table column denoting the channels relation/edge. + ChannelsColumn = "identity_id" + // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge. + AdoptionDecisionsTable = "identity_adoption_decisions" + // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionsInverseTable = "identity_adoption_decisions" + // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge. + AdoptionDecisionsColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldUserID, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldVerifiedAt, + FieldIssuer, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByVerifiedAt orders the results by the verified_at field. +func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc() +} + +// ByIssuer orders the results by the issuer field. +func ByIssuer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIssuer, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChannelsCount orders the results by channels count. +func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...) + } +} + +// ByChannels orders the results by channels terms. +func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAdoptionDecisionsCount orders the results by adoption_decisions count. +func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...) + } +} + +// ByAdoptionDecisions orders the results by adoption_decisions terms. +func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newChannelsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChannelsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) +} +func newAdoptionDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) +} diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go new file mode 100644 index 00000000..3dbf3178 --- /dev/null +++ b/backend/ent/authidentity/where.go @@ -0,0 +1,600 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ. +func VerifiedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ. +func Issuer(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// VerifiedAtEQ applies the EQ predicate on the "verified_at" field. +func VerifiedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field. +func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtIn applies the In predicate on the "verified_at" field. +func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field. +func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtGT applies the GT predicate on the "verified_at" field. +func VerifiedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v)) +} + +// VerifiedAtGTE applies the GTE predicate on the "verified_at" field. +func VerifiedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v)) +} + +// VerifiedAtLT applies the LT predicate on the "verified_at" field. +func VerifiedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v)) +} + +// VerifiedAtLTE applies the LTE predicate on the "verified_at" field. +func VerifiedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v)) +} + +// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field. +func VerifiedAtIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt)) +} + +// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field. +func VerifiedAtNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt)) +} + +// IssuerEQ applies the EQ predicate on the "issuer" field. +func IssuerEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// IssuerNEQ applies the NEQ predicate on the "issuer" field. +func IssuerNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v)) +} + +// IssuerIn applies the In predicate on the "issuer" field. +func IssuerIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...)) +} + +// IssuerNotIn applies the NotIn predicate on the "issuer" field. +func IssuerNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...)) +} + +// IssuerGT applies the GT predicate on the "issuer" field. +func IssuerGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v)) +} + +// IssuerGTE applies the GTE predicate on the "issuer" field. +func IssuerGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v)) +} + +// IssuerLT applies the LT predicate on the "issuer" field. +func IssuerLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v)) +} + +// IssuerLTE applies the LTE predicate on the "issuer" field. +func IssuerLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v)) +} + +// IssuerContains applies the Contains predicate on the "issuer" field. +func IssuerContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v)) +} + +// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field. +func IssuerHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v)) +} + +// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field. +func IssuerHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v)) +} + +// IssuerIsNil applies the IsNil predicate on the "issuer" field. +func IssuerIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer)) +} + +// IssuerNotNil applies the NotNil predicate on the "issuer" field. +func IssuerNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer)) +} + +// IssuerEqualFold applies the EqualFold predicate on the "issuer" field. +func IssuerEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v)) +} + +// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field. +func IssuerContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChannels applies the HasEdge predicate on the "channels" edge. +func HasChannels() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates). +func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newChannelsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge. +func HasAdoptionDecisions() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates). +func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newAdoptionDecisionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go new file mode 100644 index 00000000..e287705c --- /dev/null +++ b/backend/ent/authidentity_create.go @@ -0,0 +1,1036 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityCreate is the builder for creating a AuthIdentity entity. +type AuthIdentityCreate struct { + config + mutation *AuthIdentityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetVerifiedAt sets the "verified_at" field. +func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetVerifiedAt(v) + return _c +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetVerifiedAt(*v) + } + return _c +} + +// SetIssuer sets the "issuer" field. +func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate { + _c.mutation.SetIssuer(v) + return _c +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate { + if v != nil { + _c.SetIssuer(*v) + } + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate { + return _c.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddChannelIDs(ids...) + return _c +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddAdoptionDecisionIDs(ids...) + return _c +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation { + return _c.mutation +} + +// Save creates the AuthIdentity in the database. +func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentity.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentity.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentity.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)} + } + return nil +} + +func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + _node.VerifiedAt = &value + } + if value, ok := _c.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + _node.Issuer = &value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne { + _c.conflict = opts + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityUpsertOne is the builder for "upsert"-ing + // one AuthIdentity node. + AuthIdentityUpsertOne struct { + create *AuthIdentityCreate + } + + // AuthIdentityUpsert is the "OnConflict" setter. + AuthIdentityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUpdatedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert { + u.Set(authidentity.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUserID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderSubject) + return u +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldVerifiedAt, v) + return u +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldVerifiedAt) + return u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldVerifiedAt) + return u +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldIssuer, v) + return u +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldIssuer) + return u +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldIssuer) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert { + u.Set(authidentity.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk. +type AuthIdentityCreateBulk struct { + config + err error + builders []*AuthIdentityCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentity entities in the database. +func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk { + _c.conflict = opts + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// AuthIdentityUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentity nodes. +type AuthIdentityUpsertBulk struct { + create *AuthIdentityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go new file mode 100644 index 00000000..4f1f6f3c --- /dev/null +++ b/backend/ent/authidentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityDelete is the builder for deleting a AuthIdentity entity. +type AuthIdentityDelete struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity. +type AuthIdentityDeleteOne struct { + _d *AuthIdentityDelete +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go new file mode 100644 index 00000000..ff27ef3c --- /dev/null +++ b/backend/ent/authidentity_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityQuery is the builder for querying AuthIdentity entities. +type AuthIdentityQuery struct { + config + ctx *QueryContext + order []authidentity.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentity + withUser *UserQuery + withChannels *AuthIdentityChannelQuery + withAdoptionDecisions *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityQuery builder. +func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *AuthIdentityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChannels chains the current query on the "channels" edge. +func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge. +func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentity entity from the query. +// Returns a *NotFoundError when no AuthIdentity was found. +func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentity ID from the query. +// Returns a *NotFoundError when no AuthIdentity ID was found. +func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentity entity is found. +// Returns a *NotFoundError when no AuthIdentity entities are found. +func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentity.Label} + default: + return nil, &NotSingularError{authidentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentity ID in the query. +// Returns a *NotSingularError when more than one AuthIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentity.Label} + default: + err = &NotSingularError{authidentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentities. +func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]() + return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentity IDs. +func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery { + if _q == nil { + return nil + } + return &AuthIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentity{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withChannels: _q.withChannels.Clone(), + withAdoptionDecisions: _q.withAdoptionDecisions.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithChannels tells the query-builder to eager-load the nodes that are connected to +// the "channels" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withChannels = query + return _q +} + +// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecisions = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// GroupBy(authidentity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// Select(authidentity.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q} + sbuild.label = authidentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentitySelect configured with the given aggregations. +func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) { + var ( + nodes = []*AuthIdentity{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withUser != nil, + _q.withChannels != nil, + _q.withAdoptionDecisions != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentity{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withChannels; query != nil { + if err := _q.loadChannels(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} }, + func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecisions; query != nil { + if err := _q.loadAdoptionDecisions(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} }, + func(n *AuthIdentity, e *IdentityAdoptionDecision) { + n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentity) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID) + } + query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + if fk == nil { + return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for i := range fields { + if fields[i] != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(authidentity.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities. +type AuthIdentityGroupBy struct { + selector + build *AuthIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities. +type AuthIdentitySelect struct { + *AuthIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go new file mode 100644 index 00000000..c457470b --- /dev/null +++ b/backend/ent/authidentity_update.go @@ -0,0 +1,923 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityUpdate is the builder for updating AuthIdentity entities. +type AuthIdentityUpdate struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity. +type AuthIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentity entity. +func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for _, f := range fields { + if !authidentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go new file mode 100644 index 00000000..1ff3e5d1 --- /dev/null +++ b/backend/ent/authidentitychannel.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema. +type AuthIdentityChannel struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID int64 `json:"identity_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // Channel holds the value of the "channel" field. + Channel string `json:"channel,omitempty"` + // ChannelAppID holds the value of the "channel_app_id" field. + ChannelAppID string `json:"channel_app_id,omitempty"` + // ChannelSubject holds the value of the "channel_subject" field. + ChannelSubject string `json:"channel_subject,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set. + Edges AuthIdentityChannelEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityChannelEdges struct { + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldMetadata: + values[i] = new([]byte) + case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID: + values[i] = new(sql.NullInt64) + case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject: + values[i] = new(sql.NullString) + case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentityChannel fields. +func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentitychannel.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentitychannel.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentitychannel.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = value.Int64 + } + case authidentitychannel.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentitychannel.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentitychannel.FieldChannel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel", values[i]) + } else if value.Valid { + _m.Channel = value.String + } + case authidentitychannel.FieldChannelAppID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_app_id", values[i]) + } else if value.Valid { + _m.ChannelAppID = value.String + } + case authidentitychannel.FieldChannelSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_subject", values[i]) + } else if value.Valid { + _m.ChannelSubject = value.String + } + case authidentitychannel.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity. +func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery { + return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this AuthIdentityChannel. +// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne { + return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentityChannel is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentityChannel) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentityChannel(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", _m.IdentityID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("channel=") + builder.WriteString(_m.Channel) + builder.WriteString(", ") + builder.WriteString("channel_app_id=") + builder.WriteString(_m.ChannelAppID) + builder.WriteString(", ") + builder.WriteString("channel_subject=") + builder.WriteString(_m.ChannelSubject) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentityChannels is a parsable slice of AuthIdentityChannel. +type AuthIdentityChannels []*AuthIdentityChannel diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go new file mode 100644 index 00000000..7dcc98bb --- /dev/null +++ b/backend/ent/authidentitychannel/authidentitychannel.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentitychannel type in the database. + Label = "auth_identity_channel" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldChannel holds the string denoting the channel field in the database. + FieldChannel = "channel" + // FieldChannelAppID holds the string denoting the channel_app_id field in the database. + FieldChannelAppID = "channel_app_id" + // FieldChannelSubject holds the string denoting the channel_subject field in the database. + FieldChannelSubject = "channel_subject" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the authidentitychannel in the database. + Table = "auth_identity_channels" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "auth_identity_channels" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentitychannel fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldIdentityID, + FieldProviderType, + FieldProviderKey, + FieldChannel, + FieldChannelAppID, + FieldChannelSubject, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + ChannelValidator func(string) error + // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + ChannelAppIDValidator func(string) error + // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + ChannelSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentityChannel queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByChannel orders the results by the channel field. +func ByChannel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannel, opts...).ToFunc() +} + +// ByChannelAppID orders the results by the channel_app_id field. +func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelAppID, opts...).ToFunc() +} + +// ByChannelSubject orders the results by the channel_subject field. +func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelSubject, opts...).ToFunc() +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go new file mode 100644 index 00000000..827dc384 --- /dev/null +++ b/backend/ent/authidentitychannel/where.go @@ -0,0 +1,559 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ. +func Channel(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ. +func ChannelAppID(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ. +func ChannelSubject(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ChannelEQ applies the EQ predicate on the "channel" field. +func ChannelEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelNEQ applies the NEQ predicate on the "channel" field. +func ChannelNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v)) +} + +// ChannelIn applies the In predicate on the "channel" field. +func ChannelIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...)) +} + +// ChannelNotIn applies the NotIn predicate on the "channel" field. +func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...)) +} + +// ChannelGT applies the GT predicate on the "channel" field. +func ChannelGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v)) +} + +// ChannelGTE applies the GTE predicate on the "channel" field. +func ChannelGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v)) +} + +// ChannelLT applies the LT predicate on the "channel" field. +func ChannelLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v)) +} + +// ChannelLTE applies the LTE predicate on the "channel" field. +func ChannelLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v)) +} + +// ChannelContains applies the Contains predicate on the "channel" field. +func ChannelContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v)) +} + +// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field. +func ChannelHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v)) +} + +// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field. +func ChannelHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v)) +} + +// ChannelEqualFold applies the EqualFold predicate on the "channel" field. +func ChannelEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v)) +} + +// ChannelContainsFold applies the ContainsFold predicate on the "channel" field. +func ChannelContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v)) +} + +// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field. +func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field. +func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDIn applies the In predicate on the "channel_app_id" field. +func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field. +func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field. +func ChannelAppIDGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v)) +} + +// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field. +func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v)) +} + +// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field. +func ChannelAppIDLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v)) +} + +// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field. +func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v)) +} + +// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field. +func ChannelAppIDContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v)) +} + +// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field. +func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v)) +} + +// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field. +func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v)) +} + +// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field. +func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v)) +} + +// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field. +func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v)) +} + +// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field. +func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field. +func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectIn applies the In predicate on the "channel_subject" field. +func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field. +func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectGT applies the GT predicate on the "channel_subject" field. +func ChannelSubjectGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v)) +} + +// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field. +func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v)) +} + +// ChannelSubjectLT applies the LT predicate on the "channel_subject" field. +func ChannelSubjectLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v)) +} + +// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field. +func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v)) +} + +// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field. +func ChannelSubjectContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v)) +} + +// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field. +func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v)) +} + +// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field. +func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v)) +} + +// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field. +func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v)) +} + +// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field. +func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v)) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go new file mode 100644 index 00000000..4ce28479 --- /dev/null +++ b/backend/ent/authidentitychannel_create.go @@ -0,0 +1,932 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity. +type AuthIdentityChannelCreate struct { + config + mutation *AuthIdentityChannelMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetChannel sets the "channel" field. +func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannel(v) + return _c +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelAppID(v) + return _c +} + +// SetChannelSubject sets the "channel_subject" field. +func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelSubject(v) + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation { + return _c.mutation +} + +// Save creates the AuthIdentityChannel in the database. +func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityChannelCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentitychannel.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentitychannel.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentitychannel.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityChannelCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)} + } + if _, ok := _c.mutation.IdentityID(); !ok { + return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.Channel(); !ok { + return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)} + } + if v, ok := _c.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelAppID(); !ok { + return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)} + } + if v, ok := _c.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelSubject(); !ok { + return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)} + } + if v, ok := _c.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)} + } + if len(_c.mutation.IdentityIDs()) == 0 { + return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)} + } + return nil +} + +func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentityChannel{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + _node.Channel = value + } + if value, ok := _c.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + _node.ChannelAppID = value + } + if value, ok := _c.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + _node.ChannelSubject = value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne { + _c.conflict = opts + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing + // one AuthIdentityChannel node. + AuthIdentityChannelUpsertOne struct { + create *AuthIdentityChannelCreate + } + + // AuthIdentityChannelUpsert is the "OnConflict" setter. + AuthIdentityChannelUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldUpdatedAt) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldIdentityID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderKey) + return u +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannel, v) + return u +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannel) + return u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelAppID, v) + return u +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelAppID) + return u +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelSubject, v) + return u +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelSubject) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk. +type AuthIdentityChannelCreateBulk struct { + config + err error + builders []*AuthIdentityChannelCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentityChannel entities in the database. +func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentityChannel, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityChannelMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk { + _c.conflict = opts + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentityChannel nodes. +type AuthIdentityChannelUpsertBulk struct { + create *AuthIdentityChannelCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go new file mode 100644 index 00000000..1a4acac5 --- /dev/null +++ b/backend/ent/authidentitychannel_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity. +type AuthIdentityChannelDelete struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity. +type AuthIdentityChannelDeleteOne struct { + _d *AuthIdentityChannelDelete +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentitychannel.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go new file mode 100644 index 00000000..7a202b7f --- /dev/null +++ b/backend/ent/authidentitychannel_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities. +type AuthIdentityChannelQuery struct { + config + ctx *QueryContext + order []authidentitychannel.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentityChannel + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityChannelQuery builder. +func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentityChannel entity from the query. +// Returns a *NotFoundError when no AuthIdentityChannel was found. +func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentitychannel.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentityChannel ID from the query. +// Returns a *NotFoundError when no AuthIdentityChannel ID was found. +func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentitychannel.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found. +// Returns a *NotFoundError when no AuthIdentityChannel entities are found. +func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentitychannel.Label} + default: + return nil, &NotSingularError{authidentitychannel.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query. +// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentitychannel.Label} + default: + err = &NotSingularError{authidentitychannel.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentityChannels. +func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]() + return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentityChannel IDs. +func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery { + if _q == nil { + return nil + } + return &AuthIdentityChannelQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentitychannel.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// GroupBy(authidentitychannel.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityChannelGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentitychannel.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// Select(authidentitychannel.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q} + sbuild.label = authidentitychannel.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations. +func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentitychannel.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) { + var ( + nodes = []*AuthIdentityChannel{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentityChannel).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentityChannel{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentityChannel) + for i := range nodes { + fk := nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for i := range fields { + if fields[i] != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentitychannel.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentitychannel.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities. +type AuthIdentityChannelGroupBy struct { + selector + build *AuthIdentityChannelQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities. +type AuthIdentityChannelSelect struct { + *AuthIdentityChannelQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go new file mode 100644 index 00000000..b550c454 --- /dev/null +++ b/backend/ent/authidentitychannel_update.go @@ -0,0 +1,581 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities. +type AuthIdentityChannelUpdate struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity. +type AuthIdentityChannelUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentityChannel entity. +func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for _, f := range fields { + if !authidentitychannel.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentityChannel{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015a..b02f519b 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,12 +20,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -60,18 +64,26 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -118,12 +130,16 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.AuthIdentity = NewAuthIdentityClient(c.config) + c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) + c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config) c.PaymentAuditLog = NewPaymentAuditLogClient(c.config) c.PaymentOrder = NewPaymentOrderClient(c.config) c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config) + c.PendingAuthSession = NewPendingAuthSessionClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -332,11 +356,12 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *AuthIdentityMutation: + return c.AuthIdentity.mutate(ctx, m) + case *AuthIdentityChannelMutation: + return c.AuthIdentityChannel.mutate(ctx, m) case *ErrorPassthroughRuleMutation: return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *IdempotencyRecordMutation: return c.IdempotencyRecord.mutate(ctx, m) + case *IdentityAdoptionDecisionMutation: + return c.IdentityAdoptionDecision.mutate(ctx, m) case *PaymentAuditLogMutation: return c.PaymentAuditLog.mutate(ctx, m) case *PaymentOrderMutation: return c.PaymentOrder.mutate(ctx, m) case *PaymentProviderInstanceMutation: return c.PaymentProviderInstance.mutate(ctx, m) + case *PendingAuthSessionMutation: + return c.PendingAuthSession.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// AuthIdentityClient is a client for the AuthIdentity schema. +type AuthIdentityClient struct { + config +} + +// NewAuthIdentityClient returns a client for the AuthIdentity from the given config. +func NewAuthIdentityClient(c config) *AuthIdentityClient { + return &AuthIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`. +func (c *AuthIdentityClient) Use(hooks ...Hook) { + c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`. +func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...) +} + +// Create returns a builder for creating a AuthIdentity entity. +func (c *AuthIdentityClient) Create() *AuthIdentityCreate { + mutation := newAuthIdentityMutation(c.config, OpCreate) + return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentity entities. +func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk { + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentity. +func (c *AuthIdentityClient) Update() *AuthIdentityUpdate { + mutation := newAuthIdentityMutation(c.config, OpUpdate) + return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentity. +func (c *AuthIdentityClient) Delete() *AuthIdentityDelete { + mutation := newAuthIdentityMutation(c.config, OpDelete) + return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne { + builder := c.Delete().Where(authidentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentity. +func (c *AuthIdentityClient) Query() *AuthIdentityQuery { + return &AuthIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentity entity by its id. +func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) { + return c.Query().Where(authidentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChannels queries the channels edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityClient) Hooks() []Hook { + return c.hooks.AuthIdentity +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityClient) Interceptors() []Interceptor { + return c.inters.AuthIdentity +} + +func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op()) + } +} + +// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema. +type AuthIdentityChannelClient struct { + config +} + +// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config. +func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient { + return &AuthIdentityChannelClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`. +func (c *AuthIdentityChannelClient) Use(hooks ...Hook) { + c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`. +func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...) +} + +// Create returns a builder for creating a AuthIdentityChannel entity. +func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate { + mutation := newAuthIdentityChannelMutation(c.config, OpCreate) + return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities. +func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk { + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityChannelCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdate) + return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete { + mutation := newAuthIdentityChannelMutation(c.config, OpDelete) + return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne { + builder := c.Delete().Where(authidentitychannel.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityChannelDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery { + return &AuthIdentityChannelQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentityChannel}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentityChannel entity by its id. +func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) { + return c.Query().Where(authidentitychannel.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryIdentity queries the identity edge of a AuthIdentityChannel. +func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityChannelClient) Hooks() []Hook { + return c.hooks.AuthIdentityChannel +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityChannelClient) Interceptors() []Interceptor { + return c.inters.AuthIdentityChannel +} + +func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op()) + } +} + // ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. type ErrorPassthroughRuleClient struct { config @@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco } } +// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecisionClient struct { + config +} + +// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config. +func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient { + return &IdentityAdoptionDecisionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) { + c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...) +} + +// Create returns a builder for creating a IdentityAdoptionDecision entity. +func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate) + return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities. +func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk { + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdentityAdoptionDecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate) + return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete) + return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne { + builder := c.Delete().Where(identityadoptiondecision.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdentityAdoptionDecisionDeleteOne{builder} +} + +// Query returns a query builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery { + return &IdentityAdoptionDecisionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdentityAdoptionDecision}, + inters: c.Interceptors(), + } +} + +// Get returns a IdentityAdoptionDecision entity by its id. +func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) { + return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryIdentity queries the identity edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *IdentityAdoptionDecisionClient) Hooks() []Hook { + return c.hooks.IdentityAdoptionDecision +} + +// Interceptors returns the client interceptors. +func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor { + return c.inters.IdentityAdoptionDecision +} + +func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op()) + } +} + // PaymentAuditLogClient is a client for the PaymentAuditLog schema. type PaymentAuditLogClient struct { config @@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr } } +// PendingAuthSessionClient is a client for the PendingAuthSession schema. +type PendingAuthSessionClient struct { + config +} + +// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config. +func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient { + return &PendingAuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`. +func (c *PendingAuthSessionClient) Use(hooks ...Hook) { + c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`. +func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...) +} + +// Create returns a builder for creating a PendingAuthSession entity. +func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate { + mutation := newPendingAuthSessionMutation(c.config, OpCreate) + return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities. +func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk { + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PendingAuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate { + mutation := newPendingAuthSessionMutation(c.config, OpUpdate) + return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete { + mutation := newPendingAuthSessionMutation(c.config, OpDelete) + return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne { + builder := c.Delete().Where(pendingauthsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PendingAuthSessionDeleteOne{builder} +} + +// Query returns a query builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery { + return &PendingAuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePendingAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a PendingAuthSession entity by its id. +func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) { + return c.Query().Where(pendingauthsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryTargetUser queries the target_user edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *PendingAuthSessionClient) Hooks() []Hook { + return c.hooks.PendingAuthSession +} + +// Interceptors returns the client interceptors. +func (c *PendingAuthSessionClient) Interceptors() []Interceptor { + return c.inters.PendingAuthSession +} + +func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery { return query } +// QueryAuthIdentities queries the auth_identities edge of a User. +func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User. +func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 96ed5e03..339e5369 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -98,32 +102,36 @@ var ( func checkColumn(t, c string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - apikey.Table: apikey.ValidColumn, - account.Table: account.ValidColumn, - accountgroup.Table: accountgroup.ValidColumn, - announcement.Table: announcement.ValidColumn, - announcementread.Table: announcementread.ValidColumn, - errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, - group.Table: group.ValidColumn, - idempotencyrecord.Table: idempotencyrecord.ValidColumn, - paymentauditlog.Table: paymentauditlog.ValidColumn, - paymentorder.Table: paymentorder.ValidColumn, - paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, - promocode.Table: promocode.ValidColumn, - promocodeusage.Table: promocodeusage.ValidColumn, - proxy.Table: proxy.ValidColumn, - redeemcode.Table: redeemcode.ValidColumn, - securitysecret.Table: securitysecret.ValidColumn, - setting.Table: setting.ValidColumn, - subscriptionplan.Table: subscriptionplan.ValidColumn, - tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, - usagecleanuptask.Table: usagecleanuptask.ValidColumn, - usagelog.Table: usagelog.ValidColumn, - user.Table: user.ValidColumn, - userallowedgroup.Table: userallowedgroup.ValidColumn, - userattributedefinition.Table: userattributedefinition.ValidColumn, - userattributevalue.Table: userattributevalue.ValidColumn, - usersubscription.Table: usersubscription.ValidColumn, + apikey.Table: apikey.ValidColumn, + account.Table: account.ValidColumn, + accountgroup.Table: accountgroup.ValidColumn, + announcement.Table: announcement.ValidColumn, + announcementread.Table: announcementread.ValidColumn, + authidentity.Table: authidentity.ValidColumn, + authidentitychannel.Table: authidentitychannel.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, + group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, + identityadoptiondecision.Table: identityadoptiondecision.ValidColumn, + paymentauditlog.Table: paymentauditlog.ValidColumn, + paymentorder.Table: paymentorder.ValidColumn, + paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, + pendingauthsession.Table: pendingauthsession.ValidColumn, + promocode.Table: promocode.ValidColumn, + promocodeusage.Table: promocodeusage.ValidColumn, + proxy.Table: proxy.ValidColumn, + redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, + setting.Table: setting.ValidColumn, + subscriptionplan.Table: subscriptionplan.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, + usagecleanuptask.Table: usagecleanuptask.ValidColumn, + usagelog.Table: usagelog.ValidColumn, + user.Table: user.ValidColumn, + userallowedgroup.Table: userallowedgroup.ValidColumn, + userattributedefinition.Table: userattributedefinition.ValidColumn, + userattributevalue.Table: userattributevalue.ValidColumn, + usersubscription.Table: usersubscription.ValidColumn, }) }) return columnCheck(t, c) diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 199dacea..46ac02bc 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary +// function as AuthIdentity mutator. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary +// function as AuthIdentityChannel mutator. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary // function as ErrorPassthroughRule mutator. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) @@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent. return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary +// function as IdentityAdoptionDecision mutator. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary // function as PaymentAuditLog mutator. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error) @@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary +// function as PendingAuthSession mutator. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PendingAuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go new file mode 100644 index 00000000..ecaee65c --- /dev/null +++ b/backend/ent/identityadoptiondecision.go @@ -0,0 +1,223 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecision struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // PendingAuthSessionID holds the value of the "pending_auth_session_id" field. + PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID *int64 `json:"identity_id,omitempty"` + // AdoptDisplayName holds the value of the "adopt_display_name" field. + AdoptDisplayName bool `json:"adopt_display_name,omitempty"` + // AdoptAvatar holds the value of the "adopt_avatar" field. + AdoptAvatar bool `json:"adopt_avatar,omitempty"` + // DecidedAt holds the value of the "decided_at" field. + DecidedAt time.Time `json:"decided_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set. + Edges IdentityAdoptionDecisionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph. +type IdentityAdoptionDecisionEdges struct { + // PendingAuthSession holds the value of the pending_auth_session edge. + PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"` + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) { + if e.PendingAuthSession != nil { + return e.PendingAuthSession, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: pendingauthsession.Label} + } + return nil, &NotLoadedError{edge: "pending_auth_session"} +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar: + values[i] = new(sql.NullBool) + case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID: + values[i] = new(sql.NullInt64) + case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdentityAdoptionDecision fields. +func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case identityadoptiondecision.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case identityadoptiondecision.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case identityadoptiondecision.FieldPendingAuthSessionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i]) + } else if value.Valid { + _m.PendingAuthSessionID = value.Int64 + } + case identityadoptiondecision.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = new(int64) + *_m.IdentityID = value.Int64 + } + case identityadoptiondecision.FieldAdoptDisplayName: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i]) + } else if value.Valid { + _m.AdoptDisplayName = value.Bool + } + case identityadoptiondecision.FieldAdoptAvatar: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i]) + } else if value.Valid { + _m.AdoptAvatar = value.Bool + } + case identityadoptiondecision.FieldDecidedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field decided_at", values[i]) + } else if value.Valid { + _m.DecidedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision. +// This includes values selected through modifiers, order, etc. +func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m) +} + +// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this IdentityAdoptionDecision. +// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne { + return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdentityAdoptionDecision is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdentityAdoptionDecision) String() string { + var builder strings.Builder + builder.WriteString("IdentityAdoptionDecision(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pending_auth_session_id=") + builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID)) + builder.WriteString(", ") + if v := _m.IdentityID; v != nil { + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("adopt_display_name=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName)) + builder.WriteString(", ") + builder.WriteString("adopt_avatar=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar)) + builder.WriteString(", ") + builder.WriteString("decided_at=") + builder.WriteString(_m.DecidedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision. +type IdentityAdoptionDecisions []*IdentityAdoptionDecision diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go new file mode 100644 index 00000000..93adaf73 --- /dev/null +++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go @@ -0,0 +1,159 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the identityadoptiondecision type in the database. + Label = "identity_adoption_decision" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database. + FieldPendingAuthSessionID = "pending_auth_session_id" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database. + FieldAdoptDisplayName = "adopt_display_name" + // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database. + FieldAdoptAvatar = "adopt_avatar" + // FieldDecidedAt holds the string denoting the decided_at field in the database. + FieldDecidedAt = "decided_at" + // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations. + EdgePendingAuthSession = "pending_auth_session" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the identityadoptiondecision in the database. + Table = "identity_adoption_decisions" + // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge. + PendingAuthSessionTable = "identity_adoption_decisions" + // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionInverseTable = "pending_auth_sessions" + // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge. + PendingAuthSessionColumn = "pending_auth_session_id" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "identity_adoption_decisions" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for identityadoptiondecision fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldPendingAuthSessionID, + FieldIdentityID, + FieldAdoptDisplayName, + FieldAdoptAvatar, + FieldDecidedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field. + DefaultAdoptDisplayName bool + // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field. + DefaultAdoptAvatar bool + // DefaultDecidedAt holds the default value on creation for the "decided_at" field. + DefaultDecidedAt func() time.Time +) + +// OrderOption defines the ordering options for the IdentityAdoptionDecision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionID orders the results by the pending_auth_session_id field. +func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByAdoptDisplayName orders the results by the adopt_display_name field. +func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc() +} + +// ByAdoptAvatar orders the results by the adopt_avatar field. +func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc() +} + +// ByDecidedAt orders the results by the decided_at field. +func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDecidedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionField orders the results by pending_auth_session field. +func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...)) + } +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newPendingAuthSessionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go new file mode 100644 index 00000000..1968f175 --- /dev/null +++ b/backend/ent/identityadoptiondecision/where.go @@ -0,0 +1,342 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ. +func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ. +func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ. +func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ. +func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...)) +} + +// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field. +func IdentityIDIsNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID)) +} + +// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field. +func IdentityIDNotNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID)) +} + +// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field. +func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field. +func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v)) +} + +// DecidedAtEQ applies the EQ predicate on the "decided_at" field. +func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field. +func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v)) +} + +// DecidedAtIn applies the In predicate on the "decided_at" field. +func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...)) +} + +// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field. +func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...)) +} + +// DecidedAtGT applies the GT predicate on the "decided_at" field. +func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v)) +} + +// DecidedAtGTE applies the GTE predicate on the "decided_at" field. +func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v)) +} + +// DecidedAtLT applies the LT predicate on the "decided_at" field. +func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v)) +} + +// DecidedAtLTE applies the LTE predicate on the "decided_at" field. +func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v)) +} + +// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge. +func HasPendingAuthSession() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates). +func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newPendingAuthSessionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.NotPredicates(p)) +} diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go new file mode 100644 index 00000000..491ba9f9 --- /dev/null +++ b/backend/ent/identityadoptiondecision_create.go @@ -0,0 +1,843 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionCreate struct { + config + mutation *IdentityAdoptionDecisionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetPendingAuthSessionID(v) + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetIdentityID(*v) + } + return _c +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptDisplayName(v) + return _c +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptDisplayName(*v) + } + return _c +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptAvatar(v) + return _c +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptAvatar(*v) + } + return _c +} + +// SetDecidedAt sets the "decided_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetDecidedAt(v) + return _c +} + +// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetDecidedAt(*v) + } + return _c +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate { + return _c.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation { + return _c.mutation +} + +// Save creates the IdentityAdoptionDecision in the database. +func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdentityAdoptionDecisionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := identityadoptiondecision.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + v := identityadoptiondecision.DefaultAdoptDisplayName + _c.mutation.SetAdoptDisplayName(v) + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + v := identityadoptiondecision.DefaultAdoptAvatar + _c.mutation.SetAdoptAvatar(v) + } + if _, ok := _c.mutation.DecidedAt(); !ok { + v := identityadoptiondecision.DefaultDecidedAt() + _c.mutation.SetDecidedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *IdentityAdoptionDecisionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)} + } + if _, ok := _c.mutation.PendingAuthSessionID(); !ok { + return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)} + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)} + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)} + } + if _, ok := _c.mutation.DecidedAt(); !ok { + return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)} + } + if len(_c.mutation.PendingAuthSessionIDs()) == 0 { + return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)} + } + return nil +} + +func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) { + var ( + _node = &IdentityAdoptionDecision{config: _c.config} + _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + _node.AdoptDisplayName = value + } + if value, ok := _c.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + _node.AdoptAvatar = value + } + if value, ok := _c.mutation.DecidedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value) + _node.DecidedAt = value + } + if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.PendingAuthSessionID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +type ( + // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing + // one IdentityAdoptionDecision node. + IdentityAdoptionDecisionUpsertOne struct { + create *IdentityAdoptionDecisionCreate + } + + // IdentityAdoptionDecisionUpsert is the "OnConflict" setter. + IdentityAdoptionDecisionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldUpdatedAt) + return u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v) + return u +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldIdentityID) + return u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetNull(identityadoptiondecision.FieldIdentityID) + return u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptDisplayName, v) + return u +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName) + return u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptAvatar, v) + return u +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := u.create.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk. +type IdentityAdoptionDecisionCreateBulk struct { + config + err error + builders []*IdentityAdoptionDecisionCreate + conflict []sql.ConflictOption +} + +// Save creates the IdentityAdoptionDecision entities in the database. +func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdentityAdoptionDecision, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdentityAdoptionDecisionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing +// a bulk of IdentityAdoptionDecision nodes. +type IdentityAdoptionDecisionUpsertBulk struct { + create *IdentityAdoptionDecisionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := b.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go new file mode 100644 index 00000000..ef3d328d --- /dev/null +++ b/backend/ent/identityadoptiondecision_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDelete struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDeleteOne struct { + _d *IdentityAdoptionDecisionDelete +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{identityadoptiondecision.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go new file mode 100644 index 00000000..4082d8ee --- /dev/null +++ b/backend/ent/identityadoptiondecision_query.go @@ -0,0 +1,721 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionQuery struct { + config + ctx *QueryContext + order []identityadoptiondecision.OrderOption + inters []Interceptor + predicates []predicate.IdentityAdoptionDecision + withPendingAuthSession *PendingAuthSessionQuery + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder. +func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first IdentityAdoptionDecision entity from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision was found. +func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{identityadoptiondecision.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdentityAdoptionDecision ID from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found. +func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{identityadoptiondecision.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found. +// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found. +func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{identityadoptiondecision.Label} + default: + return nil, &NotSingularError{identityadoptiondecision.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{identityadoptiondecision.Label} + default: + err = &NotSingularError{identityadoptiondecision.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdentityAdoptionDecisions. +func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]() + return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdentityAdoptionDecision IDs. +func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery { + if _q == nil { + return nil + } + return &IdentityAdoptionDecisionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]identityadoptiondecision.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...), + withPendingAuthSession: _q.withPendingAuthSession.Clone(), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSession = query + return _q +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// GroupBy(identityadoptiondecision.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdentityAdoptionDecisionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = identityadoptiondecision.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// Select(identityadoptiondecision.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q} + sbuild.label = identityadoptiondecision.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations. +func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !identityadoptiondecision.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) { + var ( + nodes = []*IdentityAdoptionDecision{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withPendingAuthSession != nil, + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdentityAdoptionDecision).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdentityAdoptionDecision{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withPendingAuthSession; query != nil { + if err := _q.loadPendingAuthSession(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil { + return nil, err + } + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + fk := nodes[i].PendingAuthSessionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(pendingauthsession.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + if nodes[i].IdentityID == nil { + continue + } + fk := *nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for i := range fields { + if fields[i] != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withPendingAuthSession != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(identityadoptiondecision.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = identityadoptiondecision.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionGroupBy struct { + selector + build *IdentityAdoptionDecisionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionSelect struct { + *IdentityAdoptionDecisionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v) +} + +func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go new file mode 100644 index 00000000..0ca21d27 --- /dev/null +++ b/backend/ent/identityadoptiondecision_update.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionUpdate struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdate) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdentityAdoptionDecision entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for _, f := range fields { + if !identityadoptiondecision.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &IdentityAdoptionDecision{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bb..157c5122 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,12 +13,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) @@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + +// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error) @@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + +// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser. +type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.AuthIdentityQuery: + return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil + case *ent.AuthIdentityChannelQuery: + return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil case *ent.ErrorPassthroughRuleQuery: return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.IdempotencyRecordQuery: return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil + case *ent.IdentityAdoptionDecisionQuery: + return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil case *ent.PaymentAuditLogQuery: return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil case *ent.PaymentOrderQuery: return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil case *ent.PaymentProviderInstanceQuery: return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil + case *ent.PendingAuthSessionQuery: + return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 68bdbf55..bf41e73b 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -338,6 +338,89 @@ var ( }, }, } + // AuthIdentitiesColumns holds the columns for the "auth_identities" table. + AuthIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // AuthIdentitiesTable holds the schema information for the "auth_identities" table. + AuthIdentitiesTable = &schema.Table{ + Name: "auth_identities", + Columns: AuthIdentitiesColumns, + PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identities_users_auth_identities", + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentity_provider_type_provider_key_provider_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]}, + }, + { + Name: "authidentity_user_id", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + }, + { + Name: "authidentity_user_id_provider_type", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]}, + }, + }, + } + // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table. + AuthIdentityChannelsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel", Type: field.TypeString, Size: 20}, + {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "identity_id", Type: field.TypeInt64}, + } + // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table. + AuthIdentityChannelsTable = &schema.Table{ + Name: "auth_identity_channels", + Columns: AuthIdentityChannelsColumns, + PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identity_channels_auth_identities_channels", + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]}, + }, + { + Name: "authidentitychannel_identity_id", + Unique: false, + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + }, + }, + } // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. ErrorPassthroughRulesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -485,6 +568,49 @@ var ( }, }, } + // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "adopt_display_name", Type: field.TypeBool, Default: false}, + {Name: "adopt_avatar", Type: field.TypeBool, Default: false}, + {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "identity_id", Type: field.TypeInt64, Nullable: true}, + {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true}, + } + // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsTable = &schema.Table{ + Name: "identity_adoption_decisions", + Columns: IdentityAdoptionDecisionsColumns, + PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "identityadoptiondecision_pending_auth_session_id", + Unique: true, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + }, + { + Name: "identityadoptiondecision_identity_id", + Unique: false, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + }, + }, + } // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table. PaymentAuditLogsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -638,6 +764,72 @@ var ( }, }, } + // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table. + PendingAuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "session_token", Type: field.TypeString, Size: 255}, + {Name: "intent", Type: field.TypeString, Size: 40}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "target_user_id", Type: field.TypeInt64, Nullable: true}, + } + // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table. + PendingAuthSessionsTable = &schema.Table{ + Name: "pending_auth_sessions", + Columns: PendingAuthSessionsColumns, + PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pending_auth_sessions_users_pending_auth_sessions", + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "pendingauthsession_session_token", + Unique: true, + Columns: []*schema.Column{PendingAuthSessionsColumns[3]}, + }, + { + Name: "pendingauthsession_target_user_id", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + }, + { + Name: "pendingauthsession_expires_at", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[19]}, + }, + { + Name: "pendingauthsession_provider_type_provider_key_provider_subject", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]}, + }, + { + Name: "pendingauthsession_completion_code_hash", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[14]}, + }, + }, + } // PromoCodesColumns holds the columns for the "promo_codes" table. PromoCodesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1079,6 +1271,9 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"}, + {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -1318,12 +1513,16 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + AuthIdentitiesTable, + AuthIdentityChannelsTable, ErrorPassthroughRulesTable, GroupsTable, IdempotencyRecordsTable, + IdentityAdoptionDecisionsTable, PaymentAuditLogsTable, PaymentOrdersTable, PaymentProviderInstancesTable, + PendingAuthSessionsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1365,6 +1564,14 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable + AuthIdentitiesTable.Annotation = &entsql.Annotation{ + Table: "auth_identities", + } + AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + AuthIdentityChannelsTable.Annotation = &entsql.Annotation{ + Table: "auth_identity_channels", + } ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ Table: "error_passthrough_rules", } @@ -1374,6 +1581,11 @@ func init() { IdempotencyRecordsTable.Annotation = &entsql.Annotation{ Table: "idempotency_records", } + IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable + IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{ + Table: "identity_adoption_decisions", + } PaymentAuditLogsTable.Annotation = &entsql.Annotation{ Table: "payment_audit_logs", } @@ -1384,6 +1596,10 @@ func init() { PaymentProviderInstancesTable.Annotation = &entsql.Annotation{ Table: "payment_provider_instances", } + PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable + PendingAuthSessionsTable.Annotation = &entsql.Annotation{ + Table: "pending_auth_sessions", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 524ccb92..12905c9a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -51,32 +55,36 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. - TypeAPIKey = "APIKey" - TypeAccount = "Account" - TypeAccountGroup = "AccountGroup" - TypeAnnouncement = "Announcement" - TypeAnnouncementRead = "AnnouncementRead" - TypeErrorPassthroughRule = "ErrorPassthroughRule" - TypeGroup = "Group" - TypeIdempotencyRecord = "IdempotencyRecord" - TypePaymentAuditLog = "PaymentAuditLog" - TypePaymentOrder = "PaymentOrder" - TypePaymentProviderInstance = "PaymentProviderInstance" - TypePromoCode = "PromoCode" - TypePromoCodeUsage = "PromoCodeUsage" - TypeProxy = "Proxy" - TypeRedeemCode = "RedeemCode" - TypeSecuritySecret = "SecuritySecret" - TypeSetting = "Setting" - TypeSubscriptionPlan = "SubscriptionPlan" - TypeTLSFingerprintProfile = "TLSFingerprintProfile" - TypeUsageCleanupTask = "UsageCleanupTask" - TypeUsageLog = "UsageLog" - TypeUser = "User" - TypeUserAllowedGroup = "UserAllowedGroup" - TypeUserAttributeDefinition = "UserAttributeDefinition" - TypeUserAttributeValue = "UserAttributeValue" - TypeUserSubscription = "UserSubscription" + TypeAPIKey = "APIKey" + TypeAccount = "Account" + TypeAccountGroup = "AccountGroup" + TypeAnnouncement = "Announcement" + TypeAnnouncementRead = "AnnouncementRead" + TypeAuthIdentity = "AuthIdentity" + TypeAuthIdentityChannel = "AuthIdentityChannel" + TypeErrorPassthroughRule = "ErrorPassthroughRule" + TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" + TypeIdentityAdoptionDecision = "IdentityAdoptionDecision" + TypePaymentAuditLog = "PaymentAuditLog" + TypePaymentOrder = "PaymentOrder" + TypePaymentProviderInstance = "PaymentProviderInstance" + TypePendingAuthSession = "PendingAuthSession" + TypePromoCode = "PromoCode" + TypePromoCodeUsage = "PromoCodeUsage" + TypeProxy = "Proxy" + TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" + TypeSetting = "Setting" + TypeSubscriptionPlan = "SubscriptionPlan" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" + TypeUsageCleanupTask = "UsageCleanupTask" + TypeUsageLog = "UsageLog" + TypeUser = "User" + TypeUserAllowedGroup = "UserAllowedGroup" + TypeUserAttributeDefinition = "UserAttributeDefinition" + TypeUserAttributeValue = "UserAttributeValue" + TypeUserSubscription = "UserSubscription" ) // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. @@ -6887,6 +6895,1845 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } +// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph. +type AuthIdentityMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + provider_subject *string + verified_at *time.Time + issuer *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + user *int64 + cleareduser bool + channels map[int64]struct{} + removedchannels map[int64]struct{} + clearedchannels bool + adoption_decisions map[int64]struct{} + removedadoption_decisions map[int64]struct{} + clearedadoption_decisions bool + done bool + oldValue func(context.Context) (*AuthIdentity, error) + predicates []predicate.AuthIdentity +} + +var _ ent.Mutation = (*AuthIdentityMutation)(nil) + +// authidentityOption allows management of the mutation configuration using functional options. +type authidentityOption func(*AuthIdentityMutation) + +// newAuthIdentityMutation creates new mutation for the AuthIdentity entity. +func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation { + m := &AuthIdentityMutation{ + config: c, + op: op, + typ: TypeAuthIdentity, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAuthIdentityID sets the ID field of the mutation. +func withAuthIdentityID(id int64) authidentityOption { + return func(m *AuthIdentityMutation) { + var ( + err error + once sync.Once + value *AuthIdentity + ) + m.oldValue = func(ctx context.Context) (*AuthIdentity, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentity.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAuthIdentity sets the old AuthIdentity of the mutation. +func withAuthIdentity(node *AuthIdentity) authidentityOption { + return func(m *AuthIdentityMutation) { + m.oldValue = func(context.Context) (*AuthIdentity, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetUserID sets the "user_id" field. +func (m *AuthIdentityMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *AuthIdentityMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *AuthIdentityMutation) ResetUserID() { + m.user = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *AuthIdentityMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *AuthIdentityMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetVerifiedAt sets the "verified_at" field. +func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) { + m.verified_at = &t +} + +// VerifiedAt returns the value of the "verified_at" field in the mutation. +func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) { + v := m.verified_at + if v == nil { + return + } + return *v, true +} + +// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err) + } + return oldValue.VerifiedAt, nil +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (m *AuthIdentityMutation) ClearVerifiedAt() { + m.verified_at = nil + m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{} +} + +// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation. +func (m *AuthIdentityMutation) VerifiedAtCleared() bool { + _, ok := m.clearedFields[authidentity.FieldVerifiedAt] + return ok +} + +// ResetVerifiedAt resets all changes to the "verified_at" field. +func (m *AuthIdentityMutation) ResetVerifiedAt() { + m.verified_at = nil + delete(m.clearedFields, authidentity.FieldVerifiedAt) +} + +// SetIssuer sets the "issuer" field. +func (m *AuthIdentityMutation) SetIssuer(s string) { + m.issuer = &s +} + +// Issuer returns the value of the "issuer" field in the mutation. +func (m *AuthIdentityMutation) Issuer() (r string, exists bool) { + v := m.issuer + if v == nil { + return + } + return *v, true +} + +// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIssuer is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIssuer requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIssuer: %w", err) + } + return oldValue.Issuer, nil +} + +// ClearIssuer clears the value of the "issuer" field. +func (m *AuthIdentityMutation) ClearIssuer() { + m.issuer = nil + m.clearedFields[authidentity.FieldIssuer] = struct{}{} +} + +// IssuerCleared returns if the "issuer" field was cleared in this mutation. +func (m *AuthIdentityMutation) IssuerCleared() bool { + _, ok := m.clearedFields[authidentity.FieldIssuer] + return ok +} + +// ResetIssuer resets all changes to the "issuer" field. +func (m *AuthIdentityMutation) ResetIssuer() { + m.issuer = nil + delete(m.clearedFields, authidentity.FieldIssuer) +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *AuthIdentityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[authidentity.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *AuthIdentityMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *AuthIdentityMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids. +func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) { + if m.channels == nil { + m.channels = make(map[int64]struct{}) + } + for i := range ids { + m.channels[ids[i]] = struct{}{} + } +} + +// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) ClearChannels() { + m.clearedchannels = true +} + +// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared. +func (m *AuthIdentityMutation) ChannelsCleared() bool { + return m.clearedchannels +} + +// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs. +func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) { + if m.removedchannels == nil { + m.removedchannels = make(map[int64]struct{}) + } + for i := range ids { + delete(m.channels, ids[i]) + m.removedchannels[ids[i]] = struct{}{} + } +} + +// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) { + for id := range m.removedchannels { + ids = append(ids, id) + } + return +} + +// ChannelsIDs returns the "channels" edge IDs in the mutation. +func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) { + for id := range m.channels { + ids = append(ids, id) + } + return +} + +// ResetChannels resets all changes to the "channels" edge. +func (m *AuthIdentityMutation) ResetChannels() { + m.channels = nil + m.clearedchannels = false + m.removedchannels = nil +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids. +func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) { + if m.adoption_decisions == nil { + m.adoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + m.adoption_decisions[ids[i]] = struct{}{} + } +} + +// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) ClearAdoptionDecisions() { + m.clearedadoption_decisions = true +} + +// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared. +func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool { + return m.clearedadoption_decisions +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) { + if m.removedadoption_decisions == nil { + m.removedadoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.adoption_decisions, ids[i]) + m.removedadoption_decisions[ids[i]] = struct{}{} + } +} + +// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) { + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return +} + +// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation. +func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) { + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return +} + +// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge. +func (m *AuthIdentityMutation) ResetAdoptionDecisions() { + m.adoption_decisions = nil + m.clearedadoption_decisions = false + m.removedadoption_decisions = nil +} + +// Where appends a list predicates to the AuthIdentityMutation builder. +func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentity). +func (m *AuthIdentityMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentity.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentity.FieldUpdatedAt) + } + if m.user != nil { + fields = append(fields, authidentity.FieldUserID) + } + if m.provider_type != nil { + fields = append(fields, authidentity.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentity.FieldProviderKey) + } + if m.provider_subject != nil { + fields = append(fields, authidentity.FieldProviderSubject) + } + if m.verified_at != nil { + fields = append(fields, authidentity.FieldVerifiedAt) + } + if m.issuer != nil { + fields = append(fields, authidentity.FieldIssuer) + } + if m.metadata != nil { + fields = append(fields, authidentity.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentity.FieldCreatedAt: + return m.CreatedAt() + case authidentity.FieldUpdatedAt: + return m.UpdatedAt() + case authidentity.FieldUserID: + return m.UserID() + case authidentity.FieldProviderType: + return m.ProviderType() + case authidentity.FieldProviderKey: + return m.ProviderKey() + case authidentity.FieldProviderSubject: + return m.ProviderSubject() + case authidentity.FieldVerifiedAt: + return m.VerifiedAt() + case authidentity.FieldIssuer: + return m.Issuer() + case authidentity.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentity.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentity.FieldUserID: + return m.OldUserID(ctx) + case authidentity.FieldProviderType: + return m.OldProviderType(ctx) + case authidentity.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentity.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case authidentity.FieldVerifiedAt: + return m.OldVerifiedAt(ctx) + case authidentity.FieldIssuer: + return m.OldIssuer(ctx) + case authidentity.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentity.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentity.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case authidentity.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentity.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentity.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case authidentity.FieldVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVerifiedAt(v) + return nil + case authidentity.FieldIssuer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIssuer(v) + return nil + case authidentity.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthIdentity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(authidentity.FieldVerifiedAt) { + fields = append(fields, authidentity.FieldVerifiedAt) + } + if m.FieldCleared(authidentity.FieldIssuer) { + fields = append(fields, authidentity.FieldIssuer) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ClearField(name string) error { + switch name { + case authidentity.FieldVerifiedAt: + m.ClearVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ClearIssuer() + return nil + } + return fmt.Errorf("unknown AuthIdentity nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ResetField(name string) error { + switch name { + case authidentity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentity.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentity.FieldUserID: + m.ResetUserID() + return nil + case authidentity.FieldProviderType: + m.ResetProviderType() + return nil + case authidentity.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentity.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case authidentity.FieldVerifiedAt: + m.ResetVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ResetIssuer() + return nil + case authidentity.FieldMetadata: + m.ResetMetadata() + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.user != nil { + edges = append(edges, authidentity.EdgeUser) + } + if m.channels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.adoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.channels)) + for id := range m.channels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.adoption_decisions)) + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedchannels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.removedadoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.removedchannels)) + for id := range m.removedchannels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.removedadoption_decisions)) + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.cleareduser { + edges = append(edges, authidentity.EdgeUser) + } + if m.clearedchannels { + edges = append(edges, authidentity.EdgeChannels) + } + if m.clearedadoption_decisions { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityMutation) EdgeCleared(name string) bool { + switch name { + case authidentity.EdgeUser: + return m.cleareduser + case authidentity.EdgeChannels: + return m.clearedchannels + case authidentity.EdgeAdoptionDecisions: + return m.clearedadoption_decisions + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityMutation) ClearEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown AuthIdentity unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityMutation) ResetEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ResetUser() + return nil + case authidentity.EdgeChannels: + m.ResetChannels() + return nil + case authidentity.EdgeAdoptionDecisions: + m.ResetAdoptionDecisions() + return nil + } + return fmt.Errorf("unknown AuthIdentity edge %s", name) +} + +// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph. +type AuthIdentityChannelMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + channel *string + channel_app_id *string + channel_subject *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*AuthIdentityChannel, error) + predicates []predicate.AuthIdentityChannel +} + +var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil) + +// authidentitychannelOption allows management of the mutation configuration using functional options. +type authidentitychannelOption func(*AuthIdentityChannelMutation) + +// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity. +func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation { + m := &AuthIdentityChannelMutation{ + config: c, + op: op, + typ: TypeAuthIdentityChannel, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAuthIdentityChannelID sets the ID field of the mutation. +func withAuthIdentityChannelID(id int64) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + var ( + err error + once sync.Once + value *AuthIdentityChannel + ) + m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentityChannel.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation. +func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + m.oldValue = func(context.Context) (*AuthIdentityChannel, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityChannelMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityChannelMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityChannelMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityChannelMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true +} + +// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil +} + +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *AuthIdentityChannelMutation) ResetIdentityID() { + m.identity = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityChannelMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityChannelMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityChannelMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityChannelMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetChannel sets the "channel" field. +func (m *AuthIdentityChannelMutation) SetChannel(s string) { + m.channel = &s +} + +// Channel returns the value of the "channel" field in the mutation. +func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) { + v := m.channel + if v == nil { + return + } + return *v, true +} + +// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannel: %w", err) + } + return oldValue.Channel, nil +} + +// ResetChannel resets all changes to the "channel" field. +func (m *AuthIdentityChannelMutation) ResetChannel() { + m.channel = nil +} + +// SetChannelAppID sets the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) { + m.channel_app_id = &s +} + +// ChannelAppID returns the value of the "channel_app_id" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) { + v := m.channel_app_id + if v == nil { + return + } + return *v, true +} + +// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelAppID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err) + } + return oldValue.ChannelAppID, nil +} + +// ResetChannelAppID resets all changes to the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) ResetChannelAppID() { + m.channel_app_id = nil +} + +// SetChannelSubject sets the "channel_subject" field. +func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) { + m.channel_subject = &s +} + +// ChannelSubject returns the value of the "channel_subject" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) { + v := m.channel_subject + if v == nil { + return + } + return *v, true +} + +// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err) + } + return oldValue.ChannelSubject, nil +} + +// ResetChannelSubject resets all changes to the "channel_subject" field. +func (m *AuthIdentityChannelMutation) ResetChannelSubject() { + m.channel_subject = nil +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityChannelMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *AuthIdentityChannelMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *AuthIdentityChannelMutation) IdentityCleared() bool { + return m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *AuthIdentityChannelMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the AuthIdentityChannelMutation builder. +func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentityChannel, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityChannelMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityChannelMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentityChannel). +func (m *AuthIdentityChannelMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityChannelMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentitychannel.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentitychannel.FieldUpdatedAt) + } + if m.identity != nil { + fields = append(fields, authidentitychannel.FieldIdentityID) + } + if m.provider_type != nil { + fields = append(fields, authidentitychannel.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentitychannel.FieldProviderKey) + } + if m.channel != nil { + fields = append(fields, authidentitychannel.FieldChannel) + } + if m.channel_app_id != nil { + fields = append(fields, authidentitychannel.FieldChannelAppID) + } + if m.channel_subject != nil { + fields = append(fields, authidentitychannel.FieldChannelSubject) + } + if m.metadata != nil { + fields = append(fields, authidentitychannel.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.CreatedAt() + case authidentitychannel.FieldUpdatedAt: + return m.UpdatedAt() + case authidentitychannel.FieldIdentityID: + return m.IdentityID() + case authidentitychannel.FieldProviderType: + return m.ProviderType() + case authidentitychannel.FieldProviderKey: + return m.ProviderKey() + case authidentitychannel.FieldChannel: + return m.Channel() + case authidentitychannel.FieldChannelAppID: + return m.ChannelAppID() + case authidentitychannel.FieldChannelSubject: + return m.ChannelSubject() + case authidentitychannel.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentitychannel.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentitychannel.FieldIdentityID: + return m.OldIdentityID(ctx) + case authidentitychannel.FieldProviderType: + return m.OldProviderType(ctx) + case authidentitychannel.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentitychannel.FieldChannel: + return m.OldChannel(ctx) + case authidentitychannel.FieldChannelAppID: + return m.OldChannelAppID(ctx) + case authidentitychannel.FieldChannelSubject: + return m.OldChannelSubject(ctx) + case authidentitychannel.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentitychannel.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentitychannel.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentitychannel.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case authidentitychannel.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentitychannel.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentitychannel.FieldChannel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannel(v) + return nil + case authidentitychannel.FieldChannelAppID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelAppID(v) + return nil + case authidentitychannel.FieldChannelSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelSubject(v) + return nil + case authidentitychannel.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityChannelMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityChannelMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityChannelMutation) ResetField(name string) error { + switch name { + case authidentitychannel.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentitychannel.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentitychannel.FieldIdentityID: + m.ResetIdentityID() + return nil + case authidentitychannel.FieldProviderType: + m.ResetProviderType() + return nil + case authidentitychannel.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentitychannel.FieldChannel: + m.ResetChannel() + return nil + case authidentitychannel.FieldChannelAppID: + m.ResetChannelAppID() + return nil + case authidentitychannel.FieldChannelSubject: + m.ResetChannelSubject() + return nil + case authidentitychannel.FieldMetadata: + m.ResetMetadata() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityChannelMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.identity != nil { + edges = append(edges, authidentitychannel.EdgeIdentity) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentitychannel.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityChannelMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityChannelMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedidentity { + edges = append(edges, authidentitychannel.EdgeIdentity) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool { + switch name { + case authidentitychannel.EdgeIdentity: + return m.clearedidentity + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ResetEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel edge %s", name) +} + // ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. type ErrorPassthroughRuleMutation struct { config @@ -12191,6 +14038,781 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error { return fmt.Errorf("unknown IdempotencyRecord edge %s", name) } +// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph. +type IdentityAdoptionDecisionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + adopt_display_name *bool + adopt_avatar *bool + decided_at *time.Time + clearedFields map[string]struct{} + pending_auth_session *int64 + clearedpending_auth_session bool + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*IdentityAdoptionDecision, error) + predicates []predicate.IdentityAdoptionDecision +} + +var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil) + +// identityadoptiondecisionOption allows management of the mutation configuration using functional options. +type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation) + +// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity. +func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation { + m := &IdentityAdoptionDecisionMutation{ + config: c, + op: op, + typ: TypeIdentityAdoptionDecision, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdentityAdoptionDecisionID sets the ID field of the mutation. +func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + var ( + err error + once sync.Once + value *IdentityAdoptionDecision + ) + m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation. +func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdentityAdoptionDecisionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) { + m.pending_auth_session = &i +} + +// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) { + v := m.pending_auth_session + if v == nil { + return + } + return *v, true +} + +// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err) + } + return oldValue.PendingAuthSessionID, nil +} + +// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() { + m.pending_auth_session = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true +} + +// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() { + m.identity = nil + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool { + _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID] + return ok +} + +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() { + m.identity = nil + delete(m.clearedFields, identityadoptiondecision.FieldIdentityID) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) { + m.adopt_display_name = &b +} + +// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) { + v := m.adopt_display_name + if v == nil { + return + } + return *v, true +} + +// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err) + } + return oldValue.AdoptDisplayName, nil +} + +// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() { + m.adopt_display_name = nil +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) { + m.adopt_avatar = &b +} + +// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) { + v := m.adopt_avatar + if v == nil { + return + } + return *v, true +} + +// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptAvatar requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err) + } + return oldValue.AdoptAvatar, nil +} + +// ResetAdoptAvatar resets all changes to the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() { + m.adopt_avatar = nil +} + +// SetDecidedAt sets the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) { + m.decided_at = &t +} + +// DecidedAt returns the value of the "decided_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) { + v := m.decided_at + if v == nil { + return + } + return *v, true +} + +// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDecidedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err) + } + return oldValue.DecidedAt, nil +} + +// ResetDecidedAt resets all changes to the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() { + m.decided_at = nil +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() { + m.clearedpending_auth_session = true + m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{} +} + +// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool { + return m.clearedpending_auth_session +} + +// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PendingAuthSessionID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) { + if id := m.pending_auth_session; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() { + m.pending_auth_session = nil + m.clearedpending_auth_session = false +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *IdentityAdoptionDecisionMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool { + return m.IdentityIDCleared() || m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *IdentityAdoptionDecisionMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder. +func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdentityAdoptionDecision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdentityAdoptionDecisionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdentityAdoptionDecision). +func (m *IdentityAdoptionDecisionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdentityAdoptionDecisionMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, identityadoptiondecision.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, identityadoptiondecision.FieldUpdatedAt) + } + if m.pending_auth_session != nil { + fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID) + } + if m.identity != nil { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + if m.adopt_display_name != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName) + } + if m.adopt_avatar != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptAvatar) + } + if m.decided_at != nil { + fields = append(fields, identityadoptiondecision.FieldDecidedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.CreatedAt() + case identityadoptiondecision.FieldUpdatedAt: + return m.UpdatedAt() + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.PendingAuthSessionID() + case identityadoptiondecision.FieldIdentityID: + return m.IdentityID() + case identityadoptiondecision.FieldAdoptDisplayName: + return m.AdoptDisplayName() + case identityadoptiondecision.FieldAdoptAvatar: + return m.AdoptAvatar() + case identityadoptiondecision.FieldDecidedAt: + return m.DecidedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case identityadoptiondecision.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.OldPendingAuthSessionID(ctx) + case identityadoptiondecision.FieldIdentityID: + return m.OldIdentityID(ctx) + case identityadoptiondecision.FieldAdoptDisplayName: + return m.OldAdoptDisplayName(ctx) + case identityadoptiondecision.FieldAdoptAvatar: + return m.OldAdoptAvatar(ctx) + case identityadoptiondecision.FieldDecidedAt: + return m.OldDecidedAt(ctx) + } + return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case identityadoptiondecision.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPendingAuthSessionID(v) + return nil + case identityadoptiondecision.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptDisplayName(v) + return nil + case identityadoptiondecision.FieldAdoptAvatar: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptAvatar(v) + return nil + case identityadoptiondecision.FieldDecidedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDecidedAt(v) + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(identityadoptiondecision.FieldIdentityID) { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error { + switch name { + case identityadoptiondecision.FieldIdentityID: + m.ClearIdentityID() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case identityadoptiondecision.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + m.ResetPendingAuthSessionID() + return nil + case identityadoptiondecision.FieldIdentityID: + m.ResetIdentityID() + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + m.ResetAdoptDisplayName() + return nil + case identityadoptiondecision.FieldAdoptAvatar: + m.ResetAdoptAvatar() + return nil + case identityadoptiondecision.FieldDecidedAt: + m.ResetDecidedAt() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.pending_auth_session != nil { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.identity != nil { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + if id := m.pending_auth_session; id != nil { + return []ent.Value{*id} + } + case identityadoptiondecision.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedpending_auth_session { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.clearedidentity { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + return m.clearedpending_auth_session + case identityadoptiondecision.EdgeIdentity: + return m.clearedidentity + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ClearPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ResetPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name) +} + // PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. type PaymentAuditLogMutation struct { config @@ -16595,6 +19217,1645 @@ func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) } +// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph. +type PendingAuthSessionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + session_token *string + intent *string + provider_type *string + provider_key *string + provider_subject *string + redirect_to *string + resolved_email *string + registration_password_hash *string + upstream_identity_claims *map[string]interface{} + local_flow_state *map[string]interface{} + browser_session_key *string + completion_code_hash *string + completion_code_expires_at *time.Time + email_verified_at *time.Time + password_verified_at *time.Time + totp_verified_at *time.Time + expires_at *time.Time + consumed_at *time.Time + clearedFields map[string]struct{} + target_user *int64 + clearedtarget_user bool + adoption_decision *int64 + clearedadoption_decision bool + done bool + oldValue func(context.Context) (*PendingAuthSession, error) + predicates []predicate.PendingAuthSession +} + +var _ ent.Mutation = (*PendingAuthSessionMutation)(nil) + +// pendingauthsessionOption allows management of the mutation configuration using functional options. +type pendingauthsessionOption func(*PendingAuthSessionMutation) + +// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity. +func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation { + m := &PendingAuthSessionMutation{ + config: c, + op: op, + typ: TypePendingAuthSession, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPendingAuthSessionID sets the ID field of the mutation. +func withPendingAuthSessionID(id int64) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + var ( + err error + once sync.Once + value *PendingAuthSession + ) + m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PendingAuthSession.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPendingAuthSession sets the old PendingAuthSession of the mutation. +func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + m.oldValue = func(context.Context) (*PendingAuthSession, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PendingAuthSessionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PendingAuthSessionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PendingAuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PendingAuthSessionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetSessionToken sets the "session_token" field. +func (m *PendingAuthSessionMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. +func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *PendingAuthSessionMutation) ResetSessionToken() { + m.session_token = nil +} + +// SetIntent sets the "intent" field. +func (m *PendingAuthSessionMutation) SetIntent(s string) { + m.intent = &s +} + +// Intent returns the value of the "intent" field in the mutation. +func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) { + v := m.intent + if v == nil { + return + } + return *v, true +} + +// OldIntent returns the old "intent" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIntent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIntent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIntent: %w", err) + } + return oldValue.Intent, nil +} + +// ResetIntent resets all changes to the "intent" field. +func (m *PendingAuthSessionMutation) ResetIntent() { + m.intent = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *PendingAuthSessionMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *PendingAuthSessionMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *PendingAuthSessionMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PendingAuthSessionMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *PendingAuthSessionMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *PendingAuthSessionMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetTargetUserID sets the "target_user_id" field. +func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) { + m.target_user = &i +} + +// TargetUserID returns the value of the "target_user_id" field in the mutation. +func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) { + v := m.target_user + if v == nil { + return + } + return *v, true +} + +// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTargetUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err) + } + return oldValue.TargetUserID, nil +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (m *PendingAuthSessionMutation) ClearTargetUserID() { + m.target_user = nil + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID] + return ok +} + +// ResetTargetUserID resets all changes to the "target_user_id" field. +func (m *PendingAuthSessionMutation) ResetTargetUserID() { + m.target_user = nil + delete(m.clearedFields, pendingauthsession.FieldTargetUserID) +} + +// SetRedirectTo sets the "redirect_to" field. +func (m *PendingAuthSessionMutation) SetRedirectTo(s string) { + m.redirect_to = &s +} + +// RedirectTo returns the value of the "redirect_to" field in the mutation. +func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) { + v := m.redirect_to + if v == nil { + return + } + return *v, true +} + +// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRedirectTo requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err) + } + return oldValue.RedirectTo, nil +} + +// ResetRedirectTo resets all changes to the "redirect_to" field. +func (m *PendingAuthSessionMutation) ResetRedirectTo() { + m.redirect_to = nil +} + +// SetResolvedEmail sets the "resolved_email" field. +func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) { + m.resolved_email = &s +} + +// ResolvedEmail returns the value of the "resolved_email" field in the mutation. +func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) { + v := m.resolved_email + if v == nil { + return + } + return *v, true +} + +// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResolvedEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err) + } + return oldValue.ResolvedEmail, nil +} + +// ResetResolvedEmail resets all changes to the "resolved_email" field. +func (m *PendingAuthSessionMutation) ResetResolvedEmail() { + m.resolved_email = nil +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) { + m.registration_password_hash = &s +} + +// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation. +func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) { + v := m.registration_password_hash + if v == nil { + return + } + return *v, true +} + +// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err) + } + return oldValue.RegistrationPasswordHash, nil +} + +// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() { + m.registration_password_hash = nil +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) { + m.upstream_identity_claims = &value +} + +// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation. +func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) { + v := m.upstream_identity_claims + if v == nil { + return + } + return *v, true +} + +// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err) + } + return oldValue.UpstreamIdentityClaims, nil +} + +// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() { + m.upstream_identity_claims = nil +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) { + m.local_flow_state = &value +} + +// LocalFlowState returns the value of the "local_flow_state" field in the mutation. +func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) { + v := m.local_flow_state + if v == nil { + return + } + return *v, true +} + +// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLocalFlowState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err) + } + return oldValue.LocalFlowState, nil +} + +// ResetLocalFlowState resets all changes to the "local_flow_state" field. +func (m *PendingAuthSessionMutation) ResetLocalFlowState() { + m.local_flow_state = nil +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) { + m.browser_session_key = &s +} + +// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation. +func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) { + v := m.browser_session_key + if v == nil { + return + } + return *v, true +} + +// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err) + } + return oldValue.BrowserSessionKey, nil +} + +// ResetBrowserSessionKey resets all changes to the "browser_session_key" field. +func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() { + m.browser_session_key = nil +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) { + m.completion_code_hash = &s +} + +// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) { + v := m.completion_code_hash + if v == nil { + return + } + return *v, true +} + +// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err) + } + return oldValue.CompletionCodeHash, nil +} + +// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() { + m.completion_code_hash = nil +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) { + m.completion_code_expires_at = &t +} + +// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) { + v := m.completion_code_expires_at + if v == nil { + return + } + return *v, true +} + +// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err) + } + return oldValue.CompletionCodeExpiresAt, nil +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{} +} + +// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] + return ok +} + +// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) { + m.email_verified_at = &t +} + +// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) { + v := m.email_verified_at + if v == nil { + return + } + return *v, true +} + +// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err) + } + return oldValue.EmailVerifiedAt, nil +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() { + m.email_verified_at = nil + m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{} +} + +// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] + return ok +} + +// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() { + m.email_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) { + m.password_verified_at = &t +} + +// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) { + v := m.password_verified_at + if v == nil { + return + } + return *v, true +} + +// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err) + } + return oldValue.PasswordVerifiedAt, nil +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() { + m.password_verified_at = nil + m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{} +} + +// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] + return ok +} + +// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() { + m.password_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) { + m.totp_verified_at = &t +} + +// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) { + v := m.totp_verified_at + if v == nil { + return + } + return *v, true +} + +// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err) + } + return oldValue.TotpVerifiedAt, nil +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() { + m.totp_verified_at = nil + m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{} +} + +// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] + return ok +} + +// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() { + m.totp_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PendingAuthSessionMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetConsumedAt sets the "consumed_at" field. +func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) { + m.consumed_at = &t +} + +// ConsumedAt returns the value of the "consumed_at" field in the mutation. +func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) { + v := m.consumed_at + if v == nil { + return + } + return *v, true +} + +// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConsumedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err) + } + return oldValue.ConsumedAt, nil +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (m *PendingAuthSessionMutation) ClearConsumedAt() { + m.consumed_at = nil + m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{} +} + +// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt] + return ok +} + +// ResetConsumedAt resets all changes to the "consumed_at" field. +func (m *PendingAuthSessionMutation) ResetConsumedAt() { + m.consumed_at = nil + delete(m.clearedFields, pendingauthsession.FieldConsumedAt) +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (m *PendingAuthSessionMutation) ClearTargetUser() { + m.clearedtarget_user = true + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserCleared reports if the "target_user" edge to the User entity was cleared. +func (m *PendingAuthSessionMutation) TargetUserCleared() bool { + return m.TargetUserIDCleared() || m.clearedtarget_user +} + +// TargetUserIDs returns the "target_user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// TargetUserID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) { + if id := m.target_user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetTargetUser resets all changes to the "target_user" edge. +func (m *PendingAuthSessionMutation) ResetTargetUser() { + m.target_user = nil + m.clearedtarget_user = false +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id. +func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) { + m.adoption_decision = &id +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (m *PendingAuthSessionMutation) ClearAdoptionDecision() { + m.clearedadoption_decision = true +} + +// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared. +func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool { + return m.clearedadoption_decision +} + +// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation. +func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) { + if m.adoption_decision != nil { + return *m.adoption_decision, true + } + return +} + +// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AdoptionDecisionID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) { + if id := m.adoption_decision; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAdoptionDecision resets all changes to the "adoption_decision" edge. +func (m *PendingAuthSessionMutation) ResetAdoptionDecision() { + m.adoption_decision = nil + m.clearedadoption_decision = false +} + +// Where appends a list predicates to the PendingAuthSessionMutation builder. +func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PendingAuthSession, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PendingAuthSessionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PendingAuthSessionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PendingAuthSession). +func (m *PendingAuthSessionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PendingAuthSessionMutation) Fields() []string { + fields := make([]string, 0, 21) + if m.created_at != nil { + fields = append(fields, pendingauthsession.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, pendingauthsession.FieldUpdatedAt) + } + if m.session_token != nil { + fields = append(fields, pendingauthsession.FieldSessionToken) + } + if m.intent != nil { + fields = append(fields, pendingauthsession.FieldIntent) + } + if m.provider_type != nil { + fields = append(fields, pendingauthsession.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, pendingauthsession.FieldProviderKey) + } + if m.provider_subject != nil { + fields = append(fields, pendingauthsession.FieldProviderSubject) + } + if m.target_user != nil { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.redirect_to != nil { + fields = append(fields, pendingauthsession.FieldRedirectTo) + } + if m.resolved_email != nil { + fields = append(fields, pendingauthsession.FieldResolvedEmail) + } + if m.registration_password_hash != nil { + fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash) + } + if m.upstream_identity_claims != nil { + fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims) + } + if m.local_flow_state != nil { + fields = append(fields, pendingauthsession.FieldLocalFlowState) + } + if m.browser_session_key != nil { + fields = append(fields, pendingauthsession.FieldBrowserSessionKey) + } + if m.completion_code_hash != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeHash) + } + if m.completion_code_expires_at != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.email_verified_at != nil { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.password_verified_at != nil { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.totp_verified_at != nil { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.expires_at != nil { + fields = append(fields, pendingauthsession.FieldExpiresAt) + } + if m.consumed_at != nil { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) { + switch name { + case pendingauthsession.FieldCreatedAt: + return m.CreatedAt() + case pendingauthsession.FieldUpdatedAt: + return m.UpdatedAt() + case pendingauthsession.FieldSessionToken: + return m.SessionToken() + case pendingauthsession.FieldIntent: + return m.Intent() + case pendingauthsession.FieldProviderType: + return m.ProviderType() + case pendingauthsession.FieldProviderKey: + return m.ProviderKey() + case pendingauthsession.FieldProviderSubject: + return m.ProviderSubject() + case pendingauthsession.FieldTargetUserID: + return m.TargetUserID() + case pendingauthsession.FieldRedirectTo: + return m.RedirectTo() + case pendingauthsession.FieldResolvedEmail: + return m.ResolvedEmail() + case pendingauthsession.FieldRegistrationPasswordHash: + return m.RegistrationPasswordHash() + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.UpstreamIdentityClaims() + case pendingauthsession.FieldLocalFlowState: + return m.LocalFlowState() + case pendingauthsession.FieldBrowserSessionKey: + return m.BrowserSessionKey() + case pendingauthsession.FieldCompletionCodeHash: + return m.CompletionCodeHash() + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.CompletionCodeExpiresAt() + case pendingauthsession.FieldEmailVerifiedAt: + return m.EmailVerifiedAt() + case pendingauthsession.FieldPasswordVerifiedAt: + return m.PasswordVerifiedAt() + case pendingauthsession.FieldTotpVerifiedAt: + return m.TotpVerifiedAt() + case pendingauthsession.FieldExpiresAt: + return m.ExpiresAt() + case pendingauthsession.FieldConsumedAt: + return m.ConsumedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case pendingauthsession.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case pendingauthsession.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case pendingauthsession.FieldSessionToken: + return m.OldSessionToken(ctx) + case pendingauthsession.FieldIntent: + return m.OldIntent(ctx) + case pendingauthsession.FieldProviderType: + return m.OldProviderType(ctx) + case pendingauthsession.FieldProviderKey: + return m.OldProviderKey(ctx) + case pendingauthsession.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case pendingauthsession.FieldTargetUserID: + return m.OldTargetUserID(ctx) + case pendingauthsession.FieldRedirectTo: + return m.OldRedirectTo(ctx) + case pendingauthsession.FieldResolvedEmail: + return m.OldResolvedEmail(ctx) + case pendingauthsession.FieldRegistrationPasswordHash: + return m.OldRegistrationPasswordHash(ctx) + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.OldUpstreamIdentityClaims(ctx) + case pendingauthsession.FieldLocalFlowState: + return m.OldLocalFlowState(ctx) + case pendingauthsession.FieldBrowserSessionKey: + return m.OldBrowserSessionKey(ctx) + case pendingauthsession.FieldCompletionCodeHash: + return m.OldCompletionCodeHash(ctx) + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.OldCompletionCodeExpiresAt(ctx) + case pendingauthsession.FieldEmailVerifiedAt: + return m.OldEmailVerifiedAt(ctx) + case pendingauthsession.FieldPasswordVerifiedAt: + return m.OldPasswordVerifiedAt(ctx) + case pendingauthsession.FieldTotpVerifiedAt: + return m.OldTotpVerifiedAt(ctx) + case pendingauthsession.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case pendingauthsession.FieldConsumedAt: + return m.OldConsumedAt(ctx) + } + return nil, fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error { + switch name { + case pendingauthsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case pendingauthsession.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case pendingauthsession.FieldSessionToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSessionToken(v) + return nil + case pendingauthsession.FieldIntent: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIntent(v) + return nil + case pendingauthsession.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case pendingauthsession.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case pendingauthsession.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case pendingauthsession.FieldTargetUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTargetUserID(v) + return nil + case pendingauthsession.FieldRedirectTo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRedirectTo(v) + return nil + case pendingauthsession.FieldResolvedEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResolvedEmail(v) + return nil + case pendingauthsession.FieldRegistrationPasswordHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRegistrationPasswordHash(v) + return nil + case pendingauthsession.FieldUpstreamIdentityClaims: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamIdentityClaims(v) + return nil + case pendingauthsession.FieldLocalFlowState: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLocalFlowState(v) + return nil + case pendingauthsession.FieldBrowserSessionKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBrowserSessionKey(v) + return nil + case pendingauthsession.FieldCompletionCodeHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeHash(v) + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeExpiresAt(v) + return nil + case pendingauthsession.FieldEmailVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEmailVerifiedAt(v) + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPasswordVerifiedAt(v) + return nil + case pendingauthsession.FieldTotpVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpVerifiedAt(v) + return nil + case pendingauthsession.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case pendingauthsession.FieldConsumedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsumedAt(v) + return nil + } + return fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PendingAuthSessionMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown PendingAuthSession numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PendingAuthSessionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(pendingauthsession.FieldTargetUserID) { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldConsumedAt) { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PendingAuthSessionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PendingAuthSessionMutation) ClearField(name string) error { + switch name { + case pendingauthsession.FieldTargetUserID: + m.ClearTargetUserID() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ClearCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ClearEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ClearPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ClearTotpVerifiedAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ClearConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PendingAuthSessionMutation) ResetField(name string) error { + switch name { + case pendingauthsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case pendingauthsession.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case pendingauthsession.FieldSessionToken: + m.ResetSessionToken() + return nil + case pendingauthsession.FieldIntent: + m.ResetIntent() + return nil + case pendingauthsession.FieldProviderType: + m.ResetProviderType() + return nil + case pendingauthsession.FieldProviderKey: + m.ResetProviderKey() + return nil + case pendingauthsession.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case pendingauthsession.FieldTargetUserID: + m.ResetTargetUserID() + return nil + case pendingauthsession.FieldRedirectTo: + m.ResetRedirectTo() + return nil + case pendingauthsession.FieldResolvedEmail: + m.ResetResolvedEmail() + return nil + case pendingauthsession.FieldRegistrationPasswordHash: + m.ResetRegistrationPasswordHash() + return nil + case pendingauthsession.FieldUpstreamIdentityClaims: + m.ResetUpstreamIdentityClaims() + return nil + case pendingauthsession.FieldLocalFlowState: + m.ResetLocalFlowState() + return nil + case pendingauthsession.FieldBrowserSessionKey: + m.ResetBrowserSessionKey() + return nil + case pendingauthsession.FieldCompletionCodeHash: + m.ResetCompletionCodeHash() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ResetCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ResetEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ResetPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ResetTotpVerifiedAt() + return nil + case pendingauthsession.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ResetConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PendingAuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.target_user != nil { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.adoption_decision != nil { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value { + switch name { + case pendingauthsession.EdgeTargetUser: + if id := m.target_user; id != nil { + return []ent.Value{*id} + } + case pendingauthsession.EdgeAdoptionDecision: + if id := m.adoption_decision; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PendingAuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PendingAuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedtarget_user { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.clearedadoption_decision { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool { + switch name { + case pendingauthsession.EdgeTargetUser: + return m.clearedtarget_user + case pendingauthsession.EdgeAdoptionDecision: + return m.clearedadoption_decision + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PendingAuthSessionMutation) ClearEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ClearTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ClearAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PendingAuthSessionMutation) ResetEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ResetTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ResetAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession edge %s", name) +} + // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. type PromoCodeMutation struct { config @@ -28264,6 +32525,9 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + signup_source *string + last_login_at *time.Time + last_active_at *time.Time balance_notify_enabled *bool balance_notify_threshold_type *string balance_notify_threshold *float64 @@ -28302,6 +32566,12 @@ type UserMutation struct { payment_orders map[int64]struct{} removedpayment_orders map[int64]struct{} clearedpayment_orders bool + auth_identities map[int64]struct{} + removedauth_identities map[int64]struct{} + clearedauth_identities bool + pending_auth_sessions map[int64]struct{} + removedpending_auth_sessions map[int64]struct{} + clearedpending_auth_sessions bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -28988,6 +33258,140 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSignupSource sets the "signup_source" field. +func (m *UserMutation) SetSignupSource(s string) { + m.signup_source = &s +} + +// SignupSource returns the value of the "signup_source" field in the mutation. +func (m *UserMutation) SignupSource() (r string, exists bool) { + v := m.signup_source + if v == nil { + return + } + return *v, true +} + +// OldSignupSource returns the old "signup_source" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignupSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignupSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignupSource: %w", err) + } + return oldValue.SignupSource, nil +} + +// ResetSignupSource resets all changes to the "signup_source" field. +func (m *UserMutation) ResetSignupSource() { + m.signup_source = nil +} + +// SetLastLoginAt sets the "last_login_at" field. +func (m *UserMutation) SetLastLoginAt(t time.Time) { + m.last_login_at = &t +} + +// LastLoginAt returns the value of the "last_login_at" field in the mutation. +func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) { + v := m.last_login_at + if v == nil { + return + } + return *v, true +} + +// OldLastLoginAt returns the old "last_login_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLoginAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err) + } + return oldValue.LastLoginAt, nil +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (m *UserMutation) ClearLastLoginAt() { + m.last_login_at = nil + m.clearedFields[user.FieldLastLoginAt] = struct{}{} +} + +// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation. +func (m *UserMutation) LastLoginAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastLoginAt] + return ok +} + +// ResetLastLoginAt resets all changes to the "last_login_at" field. +func (m *UserMutation) ResetLastLoginAt() { + m.last_login_at = nil + delete(m.clearedFields, user.FieldLastLoginAt) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (m *UserMutation) SetLastActiveAt(t time.Time) { + m.last_active_at = &t +} + +// LastActiveAt returns the value of the "last_active_at" field in the mutation. +func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) { + v := m.last_active_at + if v == nil { + return + } + return *v, true +} + +// OldLastActiveAt returns the old "last_active_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActiveAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err) + } + return oldValue.LastActiveAt, nil +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (m *UserMutation) ClearLastActiveAt() { + m.last_active_at = nil + m.clearedFields[user.FieldLastActiveAt] = struct{}{} +} + +// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation. +func (m *UserMutation) LastActiveAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastActiveAt] + return ok +} + +// ResetLastActiveAt resets all changes to the "last_active_at" field. +func (m *UserMutation) ResetLastActiveAt() { + m.last_active_at = nil + delete(m.clearedFields, user.FieldLastActiveAt) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { m.balance_notify_enabled = &b @@ -29762,6 +34166,114 @@ func (m *UserMutation) ResetPaymentOrders() { m.removedpayment_orders = nil } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids. +func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) { + if m.auth_identities == nil { + m.auth_identities = make(map[int64]struct{}) + } + for i := range ids { + m.auth_identities[ids[i]] = struct{}{} + } +} + +// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) ClearAuthIdentities() { + m.clearedauth_identities = true +} + +// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared. +func (m *UserMutation) AuthIdentitiesCleared() bool { + return m.clearedauth_identities +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs. +func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) { + if m.removedauth_identities == nil { + m.removedauth_identities = make(map[int64]struct{}) + } + for i := range ids { + delete(m.auth_identities, ids[i]) + m.removedauth_identities[ids[i]] = struct{}{} + } +} + +// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) { + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return +} + +// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation. +func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) { + for id := range m.auth_identities { + ids = append(ids, id) + } + return +} + +// ResetAuthIdentities resets all changes to the "auth_identities" edge. +func (m *UserMutation) ResetAuthIdentities() { + m.auth_identities = nil + m.clearedauth_identities = false + m.removedauth_identities = nil +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids. +func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) { + if m.pending_auth_sessions == nil { + m.pending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + m.pending_auth_sessions[ids[i]] = struct{}{} + } +} + +// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) ClearPendingAuthSessions() { + m.clearedpending_auth_sessions = true +} + +// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared. +func (m *UserMutation) PendingAuthSessionsCleared() bool { + return m.clearedpending_auth_sessions +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) { + if m.removedpending_auth_sessions == nil { + m.removedpending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.pending_auth_sessions, ids[i]) + m.removedpending_auth_sessions[ids[i]] = struct{}{} + } +} + +// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) { + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return +} + +// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation. +func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) { + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return +} + +// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge. +func (m *UserMutation) ResetPendingAuthSessions() { + m.pending_auth_sessions = nil + m.clearedpending_auth_sessions = false + m.removedpending_auth_sessions = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -29796,7 +34308,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 22) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29839,6 +34351,15 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.signup_source != nil { + fields = append(fields, user.FieldSignupSource) + } + if m.last_login_at != nil { + fields = append(fields, user.FieldLastLoginAt) + } + if m.last_active_at != nil { + fields = append(fields, user.FieldLastActiveAt) + } if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } @@ -29890,6 +34411,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSignupSource: + return m.SignupSource() + case user.FieldLastLoginAt: + return m.LastLoginAt() + case user.FieldLastActiveAt: + return m.LastActiveAt() case user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() case user.FieldBalanceNotifyThresholdType: @@ -29937,6 +34464,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSignupSource: + return m.OldSignupSource(ctx) + case user.FieldLastLoginAt: + return m.OldLastLoginAt(ctx) + case user.FieldLastActiveAt: + return m.OldLastActiveAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) case user.FieldBalanceNotifyThresholdType: @@ -30054,6 +34587,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSignupSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignupSource(v) + return nil + case user.FieldLastLoginAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLoginAt(v) + return nil + case user.FieldLastActiveAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActiveAt(v) + return nil case user.FieldBalanceNotifyEnabled: v, ok := value.(bool) if !ok { @@ -30179,6 +34733,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldLastLoginAt) { + fields = append(fields, user.FieldLastLoginAt) + } + if m.FieldCleared(user.FieldLastActiveAt) { + fields = append(fields, user.FieldLastActiveAt) + } if m.FieldCleared(user.FieldBalanceNotifyThreshold) { fields = append(fields, user.FieldBalanceNotifyThreshold) } @@ -30205,6 +34765,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldLastLoginAt: + m.ClearLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ClearLastActiveAt() + return nil case user.FieldBalanceNotifyThreshold: m.ClearBalanceNotifyThreshold() return nil @@ -30258,6 +34824,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSignupSource: + m.ResetSignupSource() + return nil + case user.FieldLastLoginAt: + m.ResetLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ResetLastActiveAt() + return nil case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil @@ -30279,7 +34854,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30310,6 +34885,12 @@ func (m *UserMutation) AddedEdges() []string { if m.payment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.auth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.pending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30377,13 +34958,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.auth_identities)) + for id := range m.auth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.pending_auth_sessions)) + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30414,6 +35007,12 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpayment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.removedauth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.removedpending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30481,13 +35080,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.removedauth_identities)) + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions)) + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -30518,6 +35129,12 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpayment_orders { edges = append(edges, user.EdgePaymentOrders) } + if m.clearedauth_identities { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.clearedpending_auth_sessions { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30545,6 +35162,10 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedpromo_code_usages case user.EdgePaymentOrders: return m.clearedpayment_orders + case user.EdgeAuthIdentities: + return m.clearedauth_identities + case user.EdgePendingAuthSessions: + return m.clearedpending_auth_sessions } return false } @@ -30591,6 +35212,12 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePaymentOrders: m.ResetPaymentOrders() return nil + case user.EdgeAuthIdentities: + m.ResetAuthIdentities() + return nil + case user.EdgePendingAuthSessions: + m.ResetPendingAuthSessions() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go new file mode 100644 index 00000000..e77c065f --- /dev/null +++ b/backend/ent/pendingauthsession.go @@ -0,0 +1,399 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSession is the model entity for the PendingAuthSession schema. +type PendingAuthSession struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken string `json:"session_token,omitempty"` + // Intent holds the value of the "intent" field. + Intent string `json:"intent,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // TargetUserID holds the value of the "target_user_id" field. + TargetUserID *int64 `json:"target_user_id,omitempty"` + // RedirectTo holds the value of the "redirect_to" field. + RedirectTo string `json:"redirect_to,omitempty"` + // ResolvedEmail holds the value of the "resolved_email" field. + ResolvedEmail string `json:"resolved_email,omitempty"` + // RegistrationPasswordHash holds the value of the "registration_password_hash" field. + RegistrationPasswordHash string `json:"registration_password_hash,omitempty"` + // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field. + UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"` + // LocalFlowState holds the value of the "local_flow_state" field. + LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"` + // BrowserSessionKey holds the value of the "browser_session_key" field. + BrowserSessionKey string `json:"browser_session_key,omitempty"` + // CompletionCodeHash holds the value of the "completion_code_hash" field. + CompletionCodeHash string `json:"completion_code_hash,omitempty"` + // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field. + CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"` + // EmailVerifiedAt holds the value of the "email_verified_at" field. + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + // PasswordVerifiedAt holds the value of the "password_verified_at" field. + PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"` + // TotpVerifiedAt holds the value of the "totp_verified_at" field. + TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // ConsumedAt holds the value of the "consumed_at" field. + ConsumedAt *time.Time `json:"consumed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PendingAuthSessionQuery when eager-loading is set. + Edges PendingAuthSessionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph. +type PendingAuthSessionEdges struct { + // TargetUser holds the value of the target_user edge. + TargetUser *User `json:"target_user,omitempty"` + // AdoptionDecision holds the value of the adoption_decision edge. + AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// TargetUserOrErr returns the TargetUser value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) { + if e.TargetUser != nil { + return e.TargetUser, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "target_user"} +} + +// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) { + if e.AdoptionDecision != nil { + return e.AdoptionDecision, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: identityadoptiondecision.Label} + } + return nil, &NotLoadedError{edge: "adoption_decision"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*PendingAuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState: + values[i] = new([]byte) + case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID: + values[i] = new(sql.NullInt64) + case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash: + values[i] = new(sql.NullString) + case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the PendingAuthSession fields. +func (_m *PendingAuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case pendingauthsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case pendingauthsession.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case pendingauthsession.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = value.String + } + case pendingauthsession.FieldIntent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field intent", values[i]) + } else if value.Valid { + _m.Intent = value.String + } + case pendingauthsession.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case pendingauthsession.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case pendingauthsession.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case pendingauthsession.FieldTargetUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field target_user_id", values[i]) + } else if value.Valid { + _m.TargetUserID = new(int64) + *_m.TargetUserID = value.Int64 + } + case pendingauthsession.FieldRedirectTo: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field redirect_to", values[i]) + } else if value.Valid { + _m.RedirectTo = value.String + } + case pendingauthsession.FieldResolvedEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resolved_email", values[i]) + } else if value.Valid { + _m.ResolvedEmail = value.String + } + case pendingauthsession.FieldRegistrationPasswordHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i]) + } else if value.Valid { + _m.RegistrationPasswordHash = value.String + } + case pendingauthsession.FieldUpstreamIdentityClaims: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil { + return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err) + } + } + case pendingauthsession.FieldLocalFlowState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field local_flow_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil { + return fmt.Errorf("unmarshal field local_flow_state: %w", err) + } + } + case pendingauthsession.FieldBrowserSessionKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field browser_session_key", values[i]) + } else if value.Valid { + _m.BrowserSessionKey = value.String + } + case pendingauthsession.FieldCompletionCodeHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i]) + } else if value.Valid { + _m.CompletionCodeHash = value.String + } + case pendingauthsession.FieldCompletionCodeExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i]) + } else if value.Valid { + _m.CompletionCodeExpiresAt = new(time.Time) + *_m.CompletionCodeExpiresAt = value.Time + } + case pendingauthsession.FieldEmailVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field email_verified_at", values[i]) + } else if value.Valid { + _m.EmailVerifiedAt = new(time.Time) + *_m.EmailVerifiedAt = value.Time + } + case pendingauthsession.FieldPasswordVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field password_verified_at", values[i]) + } else if value.Valid { + _m.PasswordVerifiedAt = new(time.Time) + *_m.PasswordVerifiedAt = value.Time + } + case pendingauthsession.FieldTotpVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i]) + } else if value.Valid { + _m.TotpVerifiedAt = new(time.Time) + *_m.TotpVerifiedAt = value.Time + } + case pendingauthsession.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + case pendingauthsession.FieldConsumedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field consumed_at", values[i]) + } else if value.Valid { + _m.ConsumedAt = new(time.Time) + *_m.ConsumedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *PendingAuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryTargetUser() *UserQuery { + return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m) +} + +// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m) +} + +// Update returns a builder for updating this PendingAuthSession. +// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne { + return NewPendingAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *PendingAuthSession) Unwrap() *PendingAuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: PendingAuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *PendingAuthSession) String() string { + var builder strings.Builder + builder.WriteString("PendingAuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("session_token=") + builder.WriteString(_m.SessionToken) + builder.WriteString(", ") + builder.WriteString("intent=") + builder.WriteString(_m.Intent) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.TargetUserID; v != nil { + builder.WriteString("target_user_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("redirect_to=") + builder.WriteString(_m.RedirectTo) + builder.WriteString(", ") + builder.WriteString("resolved_email=") + builder.WriteString(_m.ResolvedEmail) + builder.WriteString(", ") + builder.WriteString("registration_password_hash=") + builder.WriteString(_m.RegistrationPasswordHash) + builder.WriteString(", ") + builder.WriteString("upstream_identity_claims=") + builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims)) + builder.WriteString(", ") + builder.WriteString("local_flow_state=") + builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState)) + builder.WriteString(", ") + builder.WriteString("browser_session_key=") + builder.WriteString(_m.BrowserSessionKey) + builder.WriteString(", ") + builder.WriteString("completion_code_hash=") + builder.WriteString(_m.CompletionCodeHash) + builder.WriteString(", ") + if v := _m.CompletionCodeExpiresAt; v != nil { + builder.WriteString("completion_code_expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.EmailVerifiedAt; v != nil { + builder.WriteString("email_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.PasswordVerifiedAt; v != nil { + builder.WriteString("password_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TotpVerifiedAt; v != nil { + builder.WriteString("totp_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConsumedAt; v != nil { + builder.WriteString("consumed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// PendingAuthSessions is a parsable slice of PendingAuthSession. +type PendingAuthSessions []*PendingAuthSession diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go new file mode 100644 index 00000000..8a3ac9bf --- /dev/null +++ b/backend/ent/pendingauthsession/pendingauthsession.go @@ -0,0 +1,279 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the pendingauthsession type in the database. + Label = "pending_auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldIntent holds the string denoting the intent field in the database. + FieldIntent = "intent" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldTargetUserID holds the string denoting the target_user_id field in the database. + FieldTargetUserID = "target_user_id" + // FieldRedirectTo holds the string denoting the redirect_to field in the database. + FieldRedirectTo = "redirect_to" + // FieldResolvedEmail holds the string denoting the resolved_email field in the database. + FieldResolvedEmail = "resolved_email" + // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database. + FieldRegistrationPasswordHash = "registration_password_hash" + // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database. + FieldUpstreamIdentityClaims = "upstream_identity_claims" + // FieldLocalFlowState holds the string denoting the local_flow_state field in the database. + FieldLocalFlowState = "local_flow_state" + // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database. + FieldBrowserSessionKey = "browser_session_key" + // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database. + FieldCompletionCodeHash = "completion_code_hash" + // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database. + FieldCompletionCodeExpiresAt = "completion_code_expires_at" + // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database. + FieldEmailVerifiedAt = "email_verified_at" + // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database. + FieldPasswordVerifiedAt = "password_verified_at" + // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database. + FieldTotpVerifiedAt = "totp_verified_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldConsumedAt holds the string denoting the consumed_at field in the database. + FieldConsumedAt = "consumed_at" + // EdgeTargetUser holds the string denoting the target_user edge name in mutations. + EdgeTargetUser = "target_user" + // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations. + EdgeAdoptionDecision = "adoption_decision" + // Table holds the table name of the pendingauthsession in the database. + Table = "pending_auth_sessions" + // TargetUserTable is the table that holds the target_user relation/edge. + TargetUserTable = "pending_auth_sessions" + // TargetUserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + TargetUserInverseTable = "users" + // TargetUserColumn is the table column denoting the target_user relation/edge. + TargetUserColumn = "target_user_id" + // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge. + AdoptionDecisionTable = "identity_adoption_decisions" + // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionInverseTable = "identity_adoption_decisions" + // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge. + AdoptionDecisionColumn = "pending_auth_session_id" +) + +// Columns holds all SQL columns for pendingauthsession fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldSessionToken, + FieldIntent, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldTargetUserID, + FieldRedirectTo, + FieldResolvedEmail, + FieldRegistrationPasswordHash, + FieldUpstreamIdentityClaims, + FieldLocalFlowState, + FieldBrowserSessionKey, + FieldCompletionCodeHash, + FieldCompletionCodeExpiresAt, + FieldEmailVerifiedAt, + FieldPasswordVerifiedAt, + FieldTotpVerifiedAt, + FieldExpiresAt, + FieldConsumedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + SessionTokenValidator func(string) error + // IntentValidator is a validator for the "intent" field. It is called by the builders before save. + IntentValidator func(string) error + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultRedirectTo holds the default value on creation for the "redirect_to" field. + DefaultRedirectTo string + // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field. + DefaultResolvedEmail string + // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field. + DefaultRegistrationPasswordHash string + // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field. + DefaultUpstreamIdentityClaims func() map[string]interface{} + // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field. + DefaultLocalFlowState func() map[string]interface{} + // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field. + DefaultBrowserSessionKey string + // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field. + DefaultCompletionCodeHash string +) + +// OrderOption defines the ordering options for the PendingAuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. +func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByIntent orders the results by the intent field. +func ByIntent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntent, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByTargetUserID orders the results by the target_user_id field. +func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTargetUserID, opts...).ToFunc() +} + +// ByRedirectTo orders the results by the redirect_to field. +func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRedirectTo, opts...).ToFunc() +} + +// ByResolvedEmail orders the results by the resolved_email field. +func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc() +} + +// ByRegistrationPasswordHash orders the results by the registration_password_hash field. +func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc() +} + +// ByBrowserSessionKey orders the results by the browser_session_key field. +func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc() +} + +// ByCompletionCodeHash orders the results by the completion_code_hash field. +func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc() +} + +// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field. +func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc() +} + +// ByEmailVerifiedAt orders the results by the email_verified_at field. +func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc() +} + +// ByPasswordVerifiedAt orders the results by the password_verified_at field. +func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc() +} + +// ByTotpVerifiedAt orders the results by the totp_verified_at field. +func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByConsumedAt orders the results by the consumed_at field. +func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsumedAt, opts...).ToFunc() +} + +// ByTargetUserField orders the results by target_user field. +func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAdoptionDecisionField orders the results by adoption_decision field. +func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...)) + } +} +func newTargetUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TargetUserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) +} +func newAdoptionDecisionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) +} diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go new file mode 100644 index 00000000..cb316f47 --- /dev/null +++ b/backend/ent/pendingauthsession/where.go @@ -0,0 +1,1262 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ. +func Intent(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ. +func TargetUserID(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ. +func RedirectTo(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ. +func ResolvedEmail(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ. +func RegistrationPasswordHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ. +func BrowserSessionKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ. +func CompletionCodeHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ. +func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ. +func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ. +func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ. +func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ. +func ConsumedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. +func SessionTokenNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. +func SessionTokenLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. +func SessionTokenHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// IntentEQ applies the EQ predicate on the "intent" field. +func IntentEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// IntentNEQ applies the NEQ predicate on the "intent" field. +func IntentNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v)) +} + +// IntentIn applies the In predicate on the "intent" field. +func IntentIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...)) +} + +// IntentNotIn applies the NotIn predicate on the "intent" field. +func IntentNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...)) +} + +// IntentGT applies the GT predicate on the "intent" field. +func IntentGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v)) +} + +// IntentGTE applies the GTE predicate on the "intent" field. +func IntentGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v)) +} + +// IntentLT applies the LT predicate on the "intent" field. +func IntentLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v)) +} + +// IntentLTE applies the LTE predicate on the "intent" field. +func IntentLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v)) +} + +// IntentContains applies the Contains predicate on the "intent" field. +func IntentContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v)) +} + +// IntentHasPrefix applies the HasPrefix predicate on the "intent" field. +func IntentHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v)) +} + +// IntentHasSuffix applies the HasSuffix predicate on the "intent" field. +func IntentHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v)) +} + +// IntentEqualFold applies the EqualFold predicate on the "intent" field. +func IntentEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v)) +} + +// IntentContainsFold applies the ContainsFold predicate on the "intent" field. +func IntentContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field. +func TargetUserIDEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field. +func TargetUserIDNEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v)) +} + +// TargetUserIDIn applies the In predicate on the "target_user_id" field. +func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field. +func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field. +func TargetUserIDIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID)) +} + +// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field. +func TargetUserIDNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID)) +} + +// RedirectToEQ applies the EQ predicate on the "redirect_to" field. +func RedirectToEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field. +func RedirectToNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v)) +} + +// RedirectToIn applies the In predicate on the "redirect_to" field. +func RedirectToIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...)) +} + +// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field. +func RedirectToNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...)) +} + +// RedirectToGT applies the GT predicate on the "redirect_to" field. +func RedirectToGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v)) +} + +// RedirectToGTE applies the GTE predicate on the "redirect_to" field. +func RedirectToGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v)) +} + +// RedirectToLT applies the LT predicate on the "redirect_to" field. +func RedirectToLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v)) +} + +// RedirectToLTE applies the LTE predicate on the "redirect_to" field. +func RedirectToLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v)) +} + +// RedirectToContains applies the Contains predicate on the "redirect_to" field. +func RedirectToContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v)) +} + +// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field. +func RedirectToHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v)) +} + +// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field. +func RedirectToHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v)) +} + +// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field. +func RedirectToEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v)) +} + +// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field. +func RedirectToContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v)) +} + +// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field. +func ResolvedEmailEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field. +func ResolvedEmailNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailIn applies the In predicate on the "resolved_email" field. +func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field. +func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailGT applies the GT predicate on the "resolved_email" field. +func ResolvedEmailGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v)) +} + +// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field. +func ResolvedEmailGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailLT applies the LT predicate on the "resolved_email" field. +func ResolvedEmailLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v)) +} + +// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field. +func ResolvedEmailLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field. +func ResolvedEmailContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field. +func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field. +func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v)) +} + +// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field. +func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v)) +} + +// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field. +func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field. +func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field. +func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field. +func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field. +func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field. +func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field. +func BrowserSessionKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field. +func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field. +func BrowserSessionKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field. +func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field. +func BrowserSessionKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field. +func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field. +func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field. +func CompletionCodeHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field. +func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field. +func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field. +func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field. +func CompletionCodeHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field. +func CompletionCodeHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field. +func CompletionCodeHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field. +func CompletionCodeHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field. +func CompletionCodeHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field. +func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field. +func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt)) +} + +// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt)) +} + +// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field. +func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field. +func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field. +func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field. +func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field. +func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field. +func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field. +func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field. +func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field. +func EmailVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt)) +} + +// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field. +func EmailVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt)) +} + +// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field. +func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field. +func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field. +func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt)) +} + +// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt)) +} + +// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field. +func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field. +func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field. +func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt)) +} + +// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field. +func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field. +func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v)) +} + +// ConsumedAtIn applies the In predicate on the "consumed_at" field. +func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field. +func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtGT applies the GT predicate on the "consumed_at" field. +func ConsumedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v)) +} + +// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field. +func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v)) +} + +// ConsumedAtLT applies the LT predicate on the "consumed_at" field. +func ConsumedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v)) +} + +// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field. +func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v)) +} + +// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field. +func ConsumedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt)) +} + +// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field. +func ConsumedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt)) +} + +// HasTargetUser applies the HasEdge predicate on the "target_user" edge. +func HasTargetUser() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates). +func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newTargetUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge. +func HasAdoptionDecision() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates). +func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newAdoptionDecisionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.NotPredicates(p)) +} diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go new file mode 100644 index 00000000..60276daa --- /dev/null +++ b/backend/ent/pendingauthsession_create.go @@ -0,0 +1,1815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity. +type PendingAuthSessionCreate struct { + config + mutation *PendingAuthSessionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetIntent sets the "intent" field. +func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate { + _c.mutation.SetIntent(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetTargetUserID sets the "target_user_id" field. +func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate { + _c.mutation.SetTargetUserID(v) + return _c +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate { + if v != nil { + _c.SetTargetUserID(*v) + } + return _c +} + +// SetRedirectTo sets the "redirect_to" field. +func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate { + _c.mutation.SetRedirectTo(v) + return _c +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRedirectTo(*v) + } + return _c +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate { + _c.mutation.SetResolvedEmail(v) + return _c +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetResolvedEmail(*v) + } + return _c +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetRegistrationPasswordHash(v) + return _c +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRegistrationPasswordHash(*v) + } + return _c +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetUpstreamIdentityClaims(v) + return _c +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetLocalFlowState(v) + return _c +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetBrowserSessionKey(v) + return _c +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetBrowserSessionKey(*v) + } + return _c +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeHash(v) + return _c +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeHash(*v) + } + return _c +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeExpiresAt(v) + return _c +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeExpiresAt(*v) + } + return _c +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetEmailVerifiedAt(v) + return _c +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetEmailVerifiedAt(*v) + } + return _c +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetPasswordVerifiedAt(v) + return _c +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetPasswordVerifiedAt(*v) + } + return _c +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetTotpVerifiedAt(v) + return _c +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetTotpVerifiedAt(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetConsumedAt sets the "consumed_at" field. +func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetConsumedAt(v) + return _c +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetConsumedAt(*v) + } + return _c +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate { + return _c.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate { + _c.mutation.SetAdoptionDecisionID(id) + return _c +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate { + if id != nil { + _c = _c.SetAdoptionDecisionID(*id) + } + return _c +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate { + return _c.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation { + return _c.mutation +} + +// Save creates the PendingAuthSession in the database. +func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *PendingAuthSessionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := pendingauthsession.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := pendingauthsession.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.RedirectTo(); !ok { + v := pendingauthsession.DefaultRedirectTo + _c.mutation.SetRedirectTo(v) + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + v := pendingauthsession.DefaultResolvedEmail + _c.mutation.SetResolvedEmail(v) + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + v := pendingauthsession.DefaultRegistrationPasswordHash + _c.mutation.SetRegistrationPasswordHash(v) + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + v := pendingauthsession.DefaultUpstreamIdentityClaims() + _c.mutation.SetUpstreamIdentityClaims(v) + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + v := pendingauthsession.DefaultLocalFlowState() + _c.mutation.SetLocalFlowState(v) + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + v := pendingauthsession.DefaultBrowserSessionKey + _c.mutation.SetBrowserSessionKey(v) + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + v := pendingauthsession.DefaultCompletionCodeHash + _c.mutation.SetCompletionCodeHash(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *PendingAuthSessionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)} + } + if _, ok := _c.mutation.SessionToken(); !ok { + return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)} + } + if v, ok := _c.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if _, ok := _c.mutation.Intent(); !ok { + return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)} + } + if v, ok := _c.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.RedirectTo(); !ok { + return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)} + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)} + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)} + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)} + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)} + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)} + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)} + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)} + } + return nil +} + +func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &PendingAuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + _node.SessionToken = value + } + if value, ok := _c.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + _node.Intent = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + _node.RedirectTo = value + } + if value, ok := _c.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + _node.ResolvedEmail = value + } + if value, ok := _c.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + _node.RegistrationPasswordHash = value + } + if value, ok := _c.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + _node.UpstreamIdentityClaims = value + } + if value, ok := _c.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + _node.LocalFlowState = value + } + if value, ok := _c.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + _node.BrowserSessionKey = value + } + if value, ok := _c.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + _node.CompletionCodeHash = value + } + if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + _node.CompletionCodeExpiresAt = &value + } + if value, ok := _c.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + _node.EmailVerifiedAt = &value + } + if value, ok := _c.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + _node.PasswordVerifiedAt = &value + } + if value, ok := _c.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + _node.TotpVerifiedAt = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := _c.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + _node.ConsumedAt = &value + } + if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TargetUserID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne { + _c.conflict = opts + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +type ( + // PendingAuthSessionUpsertOne is the builder for "upsert"-ing + // one PendingAuthSession node. + PendingAuthSessionUpsertOne struct { + create *PendingAuthSessionCreate + } + + // PendingAuthSessionUpsert is the "OnConflict" setter. + PendingAuthSessionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpdatedAt) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldSessionToken) + return u +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldIntent, v) + return u +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldIntent) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderSubject) + return u +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTargetUserID, v) + return u +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTargetUserID) + return u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTargetUserID) + return u +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRedirectTo, v) + return u +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRedirectTo) + return u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldResolvedEmail, v) + return u +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldResolvedEmail) + return u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRegistrationPasswordHash, v) + return u +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash) + return u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v) + return u +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims) + return u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldLocalFlowState, v) + return u +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldLocalFlowState) + return u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldBrowserSessionKey, v) + return u +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldBrowserSessionKey) + return u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeHash, v) + return u +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeHash) + return u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v) + return u +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldEmailVerifiedAt, v) + return u +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldPasswordVerifiedAt, v) + return u +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTotpVerifiedAt, v) + return u +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldExpiresAt) + return u +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldConsumedAt, v) + return u +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldConsumedAt) + return u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldConsumedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk. +type PendingAuthSessionCreateBulk struct { + config + err error + builders []*PendingAuthSessionCreate + conflict []sql.ConflictOption +} + +// Save creates the PendingAuthSession entities in the database. +func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*PendingAuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PendingAuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk { + _c.conflict = opts + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing +// a bulk of PendingAuthSession nodes. +type PendingAuthSessionUpsertBulk struct { + create *PendingAuthSessionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go new file mode 100644 index 00000000..ee4fe605 --- /dev/null +++ b/backend/ent/pendingauthsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity. +type PendingAuthSessionDelete struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity. +type PendingAuthSessionDeleteOne struct { + _d *PendingAuthSessionDelete +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{pendingauthsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go new file mode 100644 index 00000000..78e29cd2 --- /dev/null +++ b/backend/ent/pendingauthsession_query.go @@ -0,0 +1,717 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities. +type PendingAuthSessionQuery struct { + config + ctx *QueryContext + order []pendingauthsession.OrderOption + inters []Interceptor + predicates []predicate.PendingAuthSession + withTargetUser *UserQuery + withAdoptionDecision *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PendingAuthSessionQuery builder. +func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryTargetUser chains the current query on the "target_user" edge. +func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecision chains the current query on the "adoption_decision" edge. +func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first PendingAuthSession entity from the query. +// Returns a *NotFoundError when no PendingAuthSession was found. +func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{pendingauthsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first PendingAuthSession ID from the query. +// Returns a *NotFoundError when no PendingAuthSession ID was found. +func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{pendingauthsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one PendingAuthSession entity is found. +// Returns a *NotFoundError when no PendingAuthSession entities are found. +func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{pendingauthsession.Label} + default: + return nil, &NotSingularError{pendingauthsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only PendingAuthSession ID in the query. +// Returns a *NotSingularError when more than one PendingAuthSession ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{pendingauthsession.Label} + default: + err = &NotSingularError{pendingauthsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of PendingAuthSessions. +func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]() + return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of PendingAuthSession IDs. +func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery { + if _q == nil { + return nil + } + return &PendingAuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]pendingauthsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.PendingAuthSession{}, _q.predicates...), + withTargetUser: _q.withTargetUser.Clone(), + withAdoptionDecision: _q.withAdoptionDecision.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithTargetUser tells the query-builder to eager-load the nodes that are connected to +// the "target_user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withTargetUser = query + return _q +} + +// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecision = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// GroupBy(pendingauthsession.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &PendingAuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = pendingauthsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// Select(pendingauthsession.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q} + sbuild.label = pendingauthsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations. +func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !pendingauthsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) { + var ( + nodes = []*PendingAuthSession{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withTargetUser != nil, + _q.withAdoptionDecision != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*PendingAuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &PendingAuthSession{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withTargetUser; query != nil { + if err := _q.loadTargetUser(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecision; query != nil { + if err := _q.loadAdoptionDecision(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*PendingAuthSession) + for i := range nodes { + if nodes[i].TargetUserID == nil { + continue + } + fk := *nodes[i].TargetUserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*PendingAuthSession) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PendingAuthSessionID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for i := range fields { + if fields[i] != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withTargetUser != nil { + _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(pendingauthsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = pendingauthsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities. +type PendingAuthSessionGroupBy struct { + selector + build *PendingAuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities. +type PendingAuthSessionSelect struct { + *PendingAuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v) +} + +func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go new file mode 100644 index 00000000..00066f69 --- /dev/null +++ b/backend/ent/pendingauthsession_update.go @@ -0,0 +1,1178 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities. +type PendingAuthSessionUpdate struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdate) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity. +type PendingAuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated PendingAuthSession entity. +func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdateOne) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for _, f := range fields { + if !pendingauthsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &PendingAuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940..0aa90b90 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,12 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// AuthIdentity is the predicate function for authidentity builders. +type AuthIdentity func(*sql.Selector) + +// AuthIdentityChannel is the predicate function for authidentitychannel builders. +type AuthIdentityChannel func(*sql.Selector) + // ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. type ErrorPassthroughRule func(*sql.Selector) @@ -30,6 +36,9 @@ type Group func(*sql.Selector) // IdempotencyRecord is the predicate function for idempotencyrecord builders. type IdempotencyRecord func(*sql.Selector) +// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders. +type IdentityAdoptionDecision func(*sql.Selector) + // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) @@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector) // PaymentProviderInstance is the predicate function for paymentproviderinstance builders. type PaymentProviderInstance func(*sql.Selector) +// PendingAuthSession is the predicate function for pendingauthsession builders. +type PendingAuthSession func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fbdd08c7..268e9ddb 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,12 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -309,6 +313,120 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + authidentityMixin := schema.AuthIdentity{}.Mixin() + authidentityMixinFields0 := authidentityMixin[0].Fields() + _ = authidentityMixinFields0 + authidentityFields := schema.AuthIdentity{}.Fields() + _ = authidentityFields + // authidentityDescCreatedAt is the schema descriptor for created_at field. + authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor() + // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time) + // authidentityDescUpdatedAt is the schema descriptor for updated_at field. + authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor() + // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time) + // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentityDescProviderType is the schema descriptor for provider_type field. + authidentityDescProviderType := authidentityFields[1].Descriptor() + // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentity.ProviderTypeValidator = func() func(string) error { + validators := authidentityDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentityDescProviderKey is the schema descriptor for provider_key field. + authidentityDescProviderKey := authidentityFields[2].Descriptor() + // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error) + // authidentityDescProviderSubject is the schema descriptor for provider_subject field. + authidentityDescProviderSubject := authidentityFields[3].Descriptor() + // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error) + // authidentityDescMetadata is the schema descriptor for metadata field. + authidentityDescMetadata := authidentityFields[6].Descriptor() + // authidentity.DefaultMetadata holds the default value on creation for the metadata field. + authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{}) + authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin() + authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields() + _ = authidentitychannelMixinFields0 + authidentitychannelFields := schema.AuthIdentityChannel{}.Fields() + _ = authidentitychannelFields + // authidentitychannelDescCreatedAt is the schema descriptor for created_at field. + authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor() + // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time) + // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field. + authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor() + // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time) + // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentitychannelDescProviderType is the schema descriptor for provider_type field. + authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor() + // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentitychannel.ProviderTypeValidator = func() func(string) error { + validators := authidentitychannelDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescProviderKey is the schema descriptor for provider_key field. + authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor() + // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error) + // authidentitychannelDescChannel is the schema descriptor for channel field. + authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor() + // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + authidentitychannel.ChannelValidator = func() func(string) error { + validators := authidentitychannelDescChannel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(channel string) error { + for _, fn := range fns { + if err := fn(channel); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field. + authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor() + // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error) + // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field. + authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor() + // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error) + // authidentitychannelDescMetadata is the schema descriptor for metadata field. + authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor() + // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field. + authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{}) errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() _ = errorpassthroughruleMixinFields0 @@ -512,6 +630,33 @@ func init() { idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) + identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin() + identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields() + _ = identityadoptiondecisionMixinFields0 + identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields() + _ = identityadoptiondecisionFields + // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field. + identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor() + // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field. + identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time) + // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field. + identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor() + // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field. + identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time) + // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time) + // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field. + identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor() + // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field. + identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool) + // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field. + identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor() + // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field. + identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool) + // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field. + identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor() + // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field. + identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time) paymentauditlogFields := schema.PaymentAuditLog{}.Fields() _ = paymentauditlogFields // paymentauditlogDescOrderID is the schema descriptor for order_id field. @@ -682,6 +827,113 @@ func init() { paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time) + pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin() + pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields() + _ = pendingauthsessionMixinFields0 + pendingauthsessionFields := schema.PendingAuthSession{}.Fields() + _ = pendingauthsessionFields + // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field. + pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor() + // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field. + pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time) + // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field. + pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor() + // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field. + pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time) + // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time) + // pendingauthsessionDescSessionToken is the schema descriptor for session_token field. + pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor() + // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + pendingauthsession.SessionTokenValidator = func() func(string) error { + validators := pendingauthsessionDescSessionToken.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(session_token string) error { + for _, fn := range fns { + if err := fn(session_token); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescIntent is the schema descriptor for intent field. + pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor() + // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save. + pendingauthsession.IntentValidator = func() func(string) error { + validators := pendingauthsessionDescIntent.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(intent string) error { + for _, fn := range fns { + if err := fn(intent); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderType is the schema descriptor for provider_type field. + pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor() + // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + pendingauthsession.ProviderTypeValidator = func() func(string) error { + validators := pendingauthsessionDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field. + pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor() + // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error) + // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field. + pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor() + // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error) + // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field. + pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor() + // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field. + pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string) + // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field. + pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor() + // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field. + pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string) + // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field. + pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor() + // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field. + pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string) + // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field. + pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor() + // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field. + pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{}) + // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field. + pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor() + // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field. + pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{}) + // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field. + pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor() + // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field. + pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string) + // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field. + pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor() + // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field. + pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -1297,20 +1549,26 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSignupSource is the schema descriptor for signup_source field. + userDescSignupSource := userFields[11].Descriptor() + // user.DefaultSignupSource holds the default value on creation for the signup_source field. + user.DefaultSignupSource = userDescSignupSource.Default.(string) + // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error) // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. - userDescBalanceNotifyEnabled := userFields[11].Descriptor() + userDescBalanceNotifyEnabled := userFields[14].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. - userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + userDescBalanceNotifyThresholdType := userFields[15].Descriptor() // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[17].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) // userDescTotalRecharged is the schema descriptor for total_recharged field. - userDescTotalRecharged := userFields[15].Descriptor() + userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go new file mode 100644 index 00000000..e4b9ac90 --- /dev/null +++ b/backend/ent/schema/auth_identity.go @@ -0,0 +1,93 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var authProviderTypes = map[string]struct{}{ + "email": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, +} + +func validateAuthProviderType(value string) error { + if _, ok := authProviderTypes[value]; ok { + return nil + } + return fmt.Errorf("invalid auth provider type %q", value) +} + +// AuthIdentity stores the canonical login identity for an account. +type AuthIdentity struct { + ent.Schema +} + +func (AuthIdentity) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identities"}, + } +} + +func (AuthIdentity) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentity) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.String("issuer"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("auth_identities"). + Field("user_id"). + Required(). + Unique(), + edge.To("channels", AuthIdentityChannel.Type), + edge.To("adoption_decisions", IdentityAdoptionDecision.Type), + } +} + +func (AuthIdentity) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "provider_subject").Unique(), + index.Fields("user_id"), + index.Fields("user_id", "provider_type"), + } +} diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go new file mode 100644 index 00000000..69f2ad02 --- /dev/null +++ b/backend/ent/schema/auth_identity_channel.go @@ -0,0 +1,72 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity. +type AuthIdentityChannel struct { + ent.Schema +} + +func (AuthIdentityChannel) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identity_channels"}, + } +} + +func (AuthIdentityChannel) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentityChannel) Fields() []ent.Field { + return []ent.Field{ + field.Int64("identity_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel"). + MaxLen(20). + NotEmpty(), + field.String("channel_app_id"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentityChannel) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("identity", AuthIdentity.Type). + Ref("channels"). + Field("identity_id"). + Required(). + Unique(), + } +} + +func (AuthIdentityChannel) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go new file mode 100644 index 00000000..de55dd69 --- /dev/null +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -0,0 +1,124 @@ +package schema + +import ( + "testing" + + "entgo.io/ent/entc/load" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationSchemas(t *testing.T) { + spec, err := (&load.Config{Path: "."}).Load() + require.NoError(t, err) + + schemas := map[string]*load.Schema{} + for _, schema := range spec.Schemas { + schemas[schema.Name] = schema + } + + authIdentity := requireSchema(t, schemas, "AuthIdentity") + requireSchemaFields(t, authIdentity, + "user_id", + "provider_type", + "provider_key", + "provider_subject", + "verified_at", + "issuer", + "metadata", + ) + requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject") + + authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel") + requireSchemaFields(t, authIdentityChannel, + "identity_id", + "provider_type", + "provider_key", + "channel", + "channel_app_id", + "channel_subject", + "metadata", + ) + requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject") + + pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession") + requireSchemaFields(t, pendingAuthSession, + "intent", + "provider_type", + "provider_key", + "provider_subject", + "target_user_id", + "redirect_to", + "resolved_email", + "registration_password_hash", + "upstream_identity_claims", + "local_flow_state", + "browser_session_key", + "completion_code_hash", + "completion_code_expires_at", + "email_verified_at", + "password_verified_at", + "totp_verified_at", + "expires_at", + "consumed_at", + ) + + adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision") + requireSchemaFields(t, adoptionDecision, + "pending_auth_session_id", + "identity_id", + "adopt_display_name", + "adopt_avatar", + "decided_at", + ) + requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id") + + userSchema := requireSchema(t, schemas, "User") + requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") +} + +func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { + t.Helper() + + schema, ok := schemas[name] + require.True(t, ok, "schema %s should exist", name) + return schema +} + +func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { + t.Helper() + + fields := map[string]struct{}{} + for _, field := range schema.Fields { + fields[field.Name] = struct{}{} + } + + for _, name := range names { + _, ok := fields[name] + require.True(t, ok, "schema %s should include field %s", schema.Name, name) + } +} + +func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { + t.Helper() + + for _, index := range schema.Indexes { + if !index.Unique { + continue + } + if len(index.Fields) != len(fields) { + continue + } + match := true + for i := range fields { + if index.Fields[i] != fields[i] { + match = false + break + } + } + if match { + return + } + } + + require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields) +} diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go new file mode 100644 index 00000000..9fdd26fb --- /dev/null +++ b/backend/ent/schema/identity_adoption_decision.go @@ -0,0 +1,70 @@ +package schema + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow. +type IdentityAdoptionDecision struct { + ent.Schema +} + +func (IdentityAdoptionDecision) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "identity_adoption_decisions"}, + } +} + +func (IdentityAdoptionDecision) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdentityAdoptionDecision) Fields() []ent.Field { + return []ent.Field{ + field.Int64("pending_auth_session_id"), + field.Int64("identity_id"). + Optional(). + Nillable(), + field.Bool("adopt_display_name"). + Default(false), + field.Bool("adopt_avatar"). + Default(false), + field.Time("decided_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (IdentityAdoptionDecision) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("pending_auth_session", PendingAuthSession.Type). + Ref("adoption_decision"). + Field("pending_auth_session_id"). + Required(). + Unique(), + edge.From("identity", AuthIdentity.Type). + Ref("adoption_decisions"). + Field("identity_id"). + Unique(), + } +} + +func (IdentityAdoptionDecision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("pending_auth_session_id").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go new file mode 100644 index 00000000..91341d49 --- /dev/null +++ b/backend/ent/schema/pending_auth_session.go @@ -0,0 +1,134 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var pendingAuthIntents = map[string]struct{}{ + "login": {}, + "bind_current_user": {}, + "adopt_existing_user_by_email": {}, +} + +func validatePendingAuthIntent(value string) error { + if _, ok := pendingAuthIntents[value]; ok { + return nil + } + return fmt.Errorf("invalid pending auth intent %q", value) +} + +// PendingAuthSession stores a short-lived post-auth decision session. +type PendingAuthSession struct { + ent.Schema +} + +func (PendingAuthSession) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "pending_auth_sessions"}, + } +} + +func (PendingAuthSession) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (PendingAuthSession) Fields() []ent.Field { + return []ent.Field{ + field.String("session_token"). + MaxLen(255). + NotEmpty(), + field.String("intent"). + MaxLen(40). + NotEmpty(). + Validate(validatePendingAuthIntent), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("target_user_id"). + Optional(). + Nillable(), + field.String("redirect_to"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("resolved_email"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("registration_password_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("upstream_identity_claims", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.JSON("local_flow_state", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.String("browser_session_key"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("completion_code_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("completion_code_expires_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("email_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("password_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("totp_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("consumed_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (PendingAuthSession) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("target_user", User.Type). + Ref("pending_auth_sessions"). + Field("target_user_id"). + Unique(), + edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Unique(), + } +} + +func (PendingAuthSession) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("session_token").Unique(), + index.Fields("target_user_id"), + index.Fields("expires_at"), + index.Fields("provider_type", "provider_key", "provider_subject"), + index.Fields("completion_code_hash"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ef52e985..bb58d9e3 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,17 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + field.String("signup_source"). + MaxLen(20). + Default("email"), + field.Time("last_login_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("last_active_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), // 余额不足通知 field.Bool("balance_notify_enabled"). @@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), + edge.To("auth_identities", AuthIdentity.Type), + edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index bb3139d5..bde3e35b 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,18 +24,26 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -202,12 +210,16 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.AuthIdentity = NewAuthIdentityClient(tx.config) + tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) + tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) + tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 9fa91f74..66f33623 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,12 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SignupSource holds the value of the "signup_source" field. + SignupSource string `json:"signup_source,omitempty"` + // LastLoginAt holds the value of the "last_login_at" field. + LastLoginAt *time.Time `json:"last_login_at,omitempty"` + // LastActiveAt holds the value of the "last_active_at" field. + LastActiveAt *time.Time `json:"last_active_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. @@ -83,11 +89,15 @@ type UserEdges struct { PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` // PaymentOrders holds the value of the payment_orders edge. PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` + // AuthIdentities holds the value of the auth_identities edge. + AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` + // PendingAuthSessions holds the value of the pending_auth_sessions edge. + PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [11]bool + loadedTypes [13]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { return nil, &NotLoadedError{edge: "payment_orders"} } +// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) { + if e.loadedTypes[10] { + return e.AuthIdentities, nil + } + return nil, &NotLoadedError{edge: "auth_identities"} +} + +// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { + if e.loadedTypes[11] { + return e.PendingAuthSessions, nil + } + return nil, &NotLoadedError{edge: "pending_auth_sessions"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[10] { + if e.loadedTypes[12] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSignupSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field signup_source", values[i]) + } else if value.Valid { + _m.SignupSource = value.String + } + case user.FieldLastLoginAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login_at", values[i]) + } else if value.Valid { + _m.LastLoginAt = new(time.Time) + *_m.LastLoginAt = value.Time + } + case user.FieldLastActiveAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_active_at", values[i]) + } else if value.Valid { + _m.LastActiveAt = new(time.Time) + *_m.LastActiveAt = value.Time + } case user.FieldBalanceNotifyEnabled: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) @@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { return NewUserClient(_m.config).QueryPaymentOrders(_m) } +// QueryAuthIdentities queries the "auth_identities" edge of the User entity. +func (_m *User) QueryAuthIdentities() *AuthIdentityQuery { + return NewUserClient(_m.config).QueryAuthIdentities(_m) +} + +// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity. +func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { + return NewUserClient(_m.config).QueryPendingAuthSessions(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) @@ -482,6 +540,19 @@ func (_m *User) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + builder.WriteString("signup_source=") + builder.WriteString(_m.SignupSource) + builder.WriteString(", ") + if v := _m.LastLoginAt; v != nil { + builder.WriteString("last_login_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.LastActiveAt; v != nil { + builder.WriteString("last_active_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index d88a3a38..567e3b14 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,12 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSignupSource holds the string denoting the signup_source field in the database. + FieldSignupSource = "signup_source" + // FieldLastLoginAt holds the string denoting the last_login_at field in the database. + FieldLastLoginAt = "last_login_at" + // FieldLastActiveAt holds the string denoting the last_active_at field in the database. + FieldLastActiveAt = "last_active_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. @@ -73,6 +79,10 @@ const ( EdgePromoCodeUsages = "promo_code_usages" // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. EdgePaymentOrders = "payment_orders" + // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations. + EdgeAuthIdentities = "auth_identities" + // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. + EdgePendingAuthSessions = "pending_auth_sessions" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -145,6 +155,20 @@ const ( PaymentOrdersInverseTable = "payment_orders" // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. PaymentOrdersColumn = "user_id" + // AuthIdentitiesTable is the table that holds the auth_identities relation/edge. + AuthIdentitiesTable = "auth_identities" + // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + AuthIdentitiesInverseTable = "auth_identities" + // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge. + AuthIdentitiesColumn = "user_id" + // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge. + PendingAuthSessionsTable = "pending_auth_sessions" + // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionsInverseTable = "pending_auth_sessions" + // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. + PendingAuthSessionsColumn = "target_user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -171,6 +195,9 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSignupSource, + FieldLastLoginAt, + FieldLastActiveAt, FieldBalanceNotifyEnabled, FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, @@ -232,6 +259,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSignupSource holds the default value on creation for the "signup_source" field. + DefaultSignupSource string + // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + SignupSourceValidator func(string) error // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. @@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySignupSource orders the results by the signup_source field. +func BySignupSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignupSource, opts...).ToFunc() +} + +// ByLastLoginAt orders the results by the last_login_at field. +func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc() +} + +// ByLastActiveAt orders the results by the last_active_at field. +func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc() +} + // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() @@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByAuthIdentitiesCount orders the results by auth_identities count. +func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...) + } +} + +// ByAuthIdentities orders the results by auth_identities terms. +func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count. +func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...) + } +} + +// ByPendingAuthSessions orders the results by pending_auth_sessions terms. +func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), ) } +func newAuthIdentitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AuthIdentitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) +} +func newPendingAuthSessionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 2788aa7a..cbcfcc26 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ. +func SignupSource(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ. +func LastLoginAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ. +func LastActiveAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SignupSourceEQ applies the EQ predicate on the "signup_source" field. +func SignupSourceEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field. +func SignupSourceNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSignupSource, v)) +} + +// SignupSourceIn applies the In predicate on the "signup_source" field. +func SignupSourceIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldSignupSource, vs...)) +} + +// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field. +func SignupSourceNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...)) +} + +// SignupSourceGT applies the GT predicate on the "signup_source" field. +func SignupSourceGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldSignupSource, v)) +} + +// SignupSourceGTE applies the GTE predicate on the "signup_source" field. +func SignupSourceGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldSignupSource, v)) +} + +// SignupSourceLT applies the LT predicate on the "signup_source" field. +func SignupSourceLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldSignupSource, v)) +} + +// SignupSourceLTE applies the LTE predicate on the "signup_source" field. +func SignupSourceLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldSignupSource, v)) +} + +// SignupSourceContains applies the Contains predicate on the "signup_source" field. +func SignupSourceContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldSignupSource, v)) +} + +// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field. +func SignupSourceHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v)) +} + +// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field. +func SignupSourceHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v)) +} + +// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field. +func SignupSourceEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldSignupSource, v)) +} + +// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field. +func SignupSourceContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldSignupSource, v)) +} + +// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field. +func LastLoginAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field. +func LastLoginAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtIn applies the In predicate on the "last_login_at" field. +func LastLoginAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field. +func LastLoginAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtGT applies the GT predicate on the "last_login_at" field. +func LastLoginAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastLoginAt, v)) +} + +// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field. +func LastLoginAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastLoginAt, v)) +} + +// LastLoginAtLT applies the LT predicate on the "last_login_at" field. +func LastLoginAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastLoginAt, v)) +} + +// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field. +func LastLoginAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastLoginAt, v)) +} + +// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field. +func LastLoginAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastLoginAt)) +} + +// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field. +func LastLoginAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastLoginAt)) +} + +// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field. +func LastActiveAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field. +func LastActiveAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtIn applies the In predicate on the "last_active_at" field. +func LastActiveAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field. +func LastActiveAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtGT applies the GT predicate on the "last_active_at" field. +func LastActiveAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastActiveAt, v)) +} + +// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field. +func LastActiveAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastActiveAt, v)) +} + +// LastActiveAtLT applies the LT predicate on the "last_active_at" field. +func LastActiveAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastActiveAt, v)) +} + +// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field. +func LastActiveAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastActiveAt, v)) +} + +// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field. +func LastActiveAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastActiveAt)) +} + +// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field. +func LastActiveAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastActiveAt)) +} + // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. func BalanceNotifyEnabledEQ(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { }) } +// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge. +func HasAuthIdentities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates). +func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newAuthIdentitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge. +func HasPendingAuthSessions() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates). +func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPendingAuthSessionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index fbc64f9c..db95e813 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSignupSource sets the "signup_source" field. +func (_c *UserCreate) SetSignupSource(v string) *UserCreate { + _c.mutation.SetSignupSource(v) + return _c +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate { + if v != nil { + _c.SetSignupSource(*v) + } + return _c +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate { + _c.mutation.SetLastLoginAt(v) + return _c +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastLoginAt(*v) + } + return _c +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate { + _c.mutation.SetLastActiveAt(v) + return _c +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastActiveAt(*v) + } + return _c +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { _c.mutation.SetBalanceNotifyEnabled(v) @@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { return _c.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate { + _c.mutation.AddAuthIdentityIDs(ids...) + return _c +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate { + _c.mutation.AddPendingAuthSessionIDs(ids...) + return _c +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SignupSource(); !ok { + v := user.DefaultSignupSource + _c.mutation.SetSignupSource(v) + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) @@ -589,6 +667,14 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SignupSource(); !ok { + return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)} + } + if v, ok := _c.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } @@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + _node.SignupSource = value + } + if value, ok := _c.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + _node.LastLoginAt = &value + } + if value, ok := _c.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + _node.LastActiveAt = &value + } if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value @@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsert) SetSignupSource(v string) *UserUpsert { + u.Set(user.FieldSignupSource, v) + return u +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsert) UpdateSignupSource() *UserUpsert { + u.SetExcluded(user.FieldSignupSource) + return u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastLoginAt, v) + return u +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert { + u.SetExcluded(user.FieldLastLoginAt) + return u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsert) ClearLastLoginAt() *UserUpsert { + u.SetNull(user.FieldLastLoginAt) + return u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastActiveAt, v) + return u +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert { + u.SetExcluded(user.FieldLastActiveAt) + return u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsert) ClearLastActiveAt() *UserUpsert { + u.SetNull(user.FieldLastActiveAt) + return u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { u.Set(user.FieldBalanceNotifyEnabled, v) @@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 113d87ac..f1ee5cfe 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -15,8 +15,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -44,6 +46,8 @@ type UserQuery struct { withAttributeValues *UserAttributeValueQuery withPromoCodeUsages *PromoCodeUsageQuery withPaymentOrders *PaymentOrderQuery + withAuthIdentities *AuthIdentityQuery + withPendingAuthSessions *PendingAuthSessionQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { return query } +// QueryAuthIdentities chains the current query on the "auth_identities" edge. +func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge. +func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { withAttributeValues: _q.withAttributeValues.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(), + withAuthIdentities: _q.withAuthIdentities.Clone(), + withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu return _q } +// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to +// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAuthIdentities = query + return _q +} + +// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSessions = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [11]bool{ + loadedTypes = [13]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withAttributeValues != nil, _q.withPromoCodeUsages != nil, _q.withPaymentOrders != nil, + _q.withAuthIdentities != nil, + _q.withPendingAuthSessions != nil, _q.withUserAllowedGroups != nil, } ) @@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withAuthIdentities; query != nil { + if err := _q.loadAuthIdentities(ctx, query, nodes, + func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} }, + func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil { + return nil, err + } + } + if query := _q.withPendingAuthSessions; query != nil { + if err := _q.loadPendingAuthSessions(ctx, query, nodes, + func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} }, + func(n *User, e *PendingAuthSession) { + n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e) + }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ } return nil } +func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentity.FieldUserID) + } + query.Where(predicate.AuthIdentity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID) + } + query.Where(predicate.PendingAuthSession(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TargetUserID + if fk == nil { + return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 6b355247..677eeb6b 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { _u.mutation.SetBalanceNotifyEnabled(v) @@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { _u.mutation.SetBalanceNotifyEnabled(v) @@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e58..6136e9ea 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1608,6 +1608,9 @@ func (c *Config) Validate() error { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } if c.LinuxDo.Enabled { + if !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true") + } if strings.TrimSpace(c.LinuxDo.ClientID) == "" { return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") } @@ -1629,9 +1632,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") @@ -1663,6 +1663,12 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } if c.OIDC.Enabled { + if !c.OIDC.UsePKCE { + return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true") + } + if !c.OIDC.ValidateIDToken { + return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true") + } if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") } @@ -1685,9 +1691,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.OIDC.ClientSecret) == "" { return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index bec0f126..fe5c7928 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) @@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { paymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, @@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } // UpdateSettingsRequest 更新设置请求 @@ -276,9 +282,30 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // 验证参数 if req.DefaultConcurrency < 1 { @@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions) + req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) + req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) + req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC scopes must contain openid") return } + if !req.OIDCConnectUsePKCE { + response.BadRequest(c, "OIDC PKCE must be enabled") + return + } + if !req.OIDCConnectValidateIDToken { + response.BadRequest(c, "OIDC ID Token validation must be enabled") + return + } switch req.OIDCConnectTokenAuthMethod { case "", "client_secret_post", "client_secret_basic", "none": default: response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") return } - if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") - return - } if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken { - if req.OIDCConnectAllowedSigningAlgs == "" { - response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") - return - } + if req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return } if req.OIDCConnectJWKSURL != "" { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { @@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults := &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + }, + LinuxDo: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + }, + OIDC: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + }, + WeChat: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + }, + ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), + } + if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil { + response.ErrorFrom(c, err) + return + } // Update payment configuration (integrated into system settings). // Skip if no payment fields were provided (prevents accidental wipe). @@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) for _, sub := range updatedSettings.DefaultSubscriptions { updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ @@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { updatedPaymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, @@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } // hasPaymentFields returns true if any payment-related field was explicitly provided. @@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting { + if input == nil { + return nil + } + normalized := normalizeDefaultSubscriptions(*input) + return &normalized +} + +func float64ValueOrDefault(value *float64, fallback float64) float64 { + if value == nil { + return fallback + } + return *value +} + +func intValueOrDefault(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func boolValueOrDefault(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting { + if input == nil { + return fallback + } + result := make([]service.DefaultSubscriptionSetting, 0, len(*input)) + for _, item := range *input { + result = append(result, service.DefaultSubscriptionSetting{ + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + }) + } + return result +} + +func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { + data := make(map[string]any) + raw, err := json.Marshal(settings) + if err == nil { + _ = json.Unmarshal(raw, &data) + } + if authSourceDefaults == nil { + authSourceDefaults = &service.AuthSourceDefaultSettings{} + } + + data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance + data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency + data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions + data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup + data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind + data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance + data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency + data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions + data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup + data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance + data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency + data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions + data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup + data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind + data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance + data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency + data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions + data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup + data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup + + return data +} + func equalStringSlice(a, b []string) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go new file mode 100644 index 00000000..b26fa447 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -0,0 +1,149 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerRepoStub struct { + values map[string]string + lastUpdates map[string]string +} + +func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.lastUpdates = make(map[string]string, len(settings)) + for key, value := range settings { + s.lastUpdates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + + handler.GetSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 9.5, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) + + subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any) + require.True(t, ok) + require.Len(t, subscriptions, 1) +} + +func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions]) + require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 12.75, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2f182642..b0edcf5a 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ProviderKey: "linuxdo", ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } type completeLinuxDoOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 90bc10d1..661c0da0 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -1,10 +1,21 @@ package handler import ( + "bytes" + "context" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "", singleLine("\n\t\r")) } + +func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-1"). + SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Display", + "suggested_avatar_url": "https://cdn.example/linuxdo.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "LinuxDo Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index a758c0b9..da8ac858 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1,10 +1,17 @@ package handler import ( + "context" + "errors" + "io" "net/http" "net/url" "strings" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -26,6 +33,7 @@ const ( type oauthPendingSessionPayload struct { Intent string Identity service.PendingAuthIdentityKey + TargetUserID *int64 ResolvedEmail string RedirectTo string BrowserSessionKey string @@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct { CompletionResponse map[string]any } +type oauthAdoptionDecisionRequest struct { + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ Intent: strings.TrimSpace(payload.Intent), Identity: payload.Identity, + TargetUserID: payload.TargetUserID, ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), RedirectTo: strings.TrimSpace(payload.RedirectTo), BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), @@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } +func (r oauthAdoptionDecisionRequest) hasDecision() bool { + return r.AdoptDisplayName != nil || r.AdoptAvatar != nil +} + +func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput { + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if r.AdoptDisplayName != nil { + input.AdoptDisplayName = *r.AdoptDisplayName + } + if r.AdoptAvatar != nil { + input.AdoptAvatar = *r.AdoptAvatar + } + return input +} + +func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { + var req oauthAdoptionDecisionRequest + if c == nil || c.Request == nil || c.Request.Body == nil { + return req, nil + } + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + return req, nil + } + return req, err + } + return req, nil +} + +func persistPendingOAuthAdoptionDecision( + c *gin.Context, + svc *service.AuthPendingIdentityService, + sessionID int64, + req oauthAdoptionDecisionRequest, +) error { + if !req.hasDecision() { + return nil + } + if svc == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil { + return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return nil +} + +func cloneOAuthMetadata(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func normalizeAdoptedOAuthDisplayName(value string) string { + value = strings.TrimSpace(value) + if len([]rune(value)) > 100 { + value = string([]rune(value)[:100]) + } + return value +} + +func (h *AuthHandler) entClient() *dbent.Client { + if h == nil || h.authService == nil { + return nil + } + return h.authService.EntClient() +} + +func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + existing, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)). + Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err) + } + if existing != nil && !req.hasDecision() { + return existing, nil + } + if existing == nil && !req.hasDecision() { + return nil, nil + } + + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if existing != nil { + input.AdoptDisplayName = existing.AdoptDisplayName + input.AdoptAvatar = existing.AdoptAvatar + input.IdentityID = existing.IdentityID + } + if req.AdoptDisplayName != nil { + input.AdoptDisplayName = *req.AdoptDisplayName + } + if req.AdoptAvatar != nil { + input.AdoptAvatar = *req.AdoptAvatar + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { + if session == nil { + return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return *session.TargetUserID, nil + } + email := strings.TrimSpace(session.ResolvedEmail) + if email == "" { + return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") + } + return 0, err + } + return userEntity.ID, nil +} + +func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { + if session == nil { + return nil + } + switch strings.TrimSpace(session.ProviderType) { + case "oidc": + issuer := strings.TrimSpace(session.ProviderKey) + if issuer == "" { + issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + } + if issuer == "" { + return nil + } + return &issuer + default: + issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + if issuer == "" { + return nil + } + return &issuer + } +} + +func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil { + if identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return identity, nil + } + + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(strings.TrimSpace(session.ProviderType)). + SetProviderKey(strings.TrimSpace(session.ProviderKey)). + SetProviderSubject(strings.TrimSpace(session.ProviderSubject)). + SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + return create.Save(ctx) +} + +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + if client == nil || session == nil || decision == nil { + return nil + } + if !decision.AdoptDisplayName && !decision.AdoptAvatar { + return nil + } + + targetUserID := int64(0) + if overrideUserID != nil && *overrideUserID > 0 { + targetUserID = *overrideUserID + } else { + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + if err != nil { + return err + } + targetUserID = resolvedUserID + } + + adoptedDisplayName := "" + if decision.AdoptDisplayName { + adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) + } + adoptedAvatarURL := "" + if decision.AdoptAvatar { + adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + if decision.AdoptDisplayName && adoptedDisplayName != "" { + if err := tx.Client().User.UpdateOneID(targetUserID). + SetUsername(adoptedDisplayName). + Exec(ctx); err != nil { + return err + } + } + + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + if err != nil { + return err + } + + metadata := cloneOAuthMetadata(identity.Metadata) + for key, value := range session.UpstreamIdentityClaims { + metadata[key] = value + } + if decision.AdoptDisplayName && adoptedDisplayName != "" { + metadata["display_name"] = adoptedDisplayName + } + if decision.AdoptAvatar && adoptedAvatarURL != "" { + metadata["avatar_url"] = adoptedAvatarURL + } + + updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) + } + if _, err := updateIdentity.Save(ctx); err != nil { + return err + } + + if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). + SetIdentityID(identity.ID). + Save(ctx); err != nil { + return err + } + } + + return tx.Commit() +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) } + adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c) + if err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil || strings.TrimSpace(sessionToken) == "" { @@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) if pendingSessionWantsInvitation(payload) { + if adoptionDecision.hasDecision() { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + _ = decision + } response.Success(c, payload) return } + if !adoptionDecision.hasDecision() { + response.Success(c, payload) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearCookies() diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 5517bae2..829fc217 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1,9 +1,30 @@ package handler import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { @@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t * require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) require.Equal(t, true, payload["adoption_required"]) } + +func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("linuxdo-123@linuxdo-connect.invalid"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "https://cdn.example/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Alice Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err = client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "Alice Example", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) + authSvc := service.NewAuthService( + client, + &oauthPendingFlowUserRepo{client: client}, + nil, + &oauthPendingFlowRefreshTokenCacheStub{}, + cfg, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func boolSettingValue(v bool) string { + if v { + return "true" + } + return "false" +} + +func boolPtr(v bool) *bool { + return &v +} + +type oauthPendingFlowSettingRepoStub struct { + values map[string]string +} + +func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type oauthPendingFlowRefreshTokenCacheStub struct{} + +func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var envelope struct { + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) + return envelope.Data +} + +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + return payload +} + +type oauthPendingFlowUserRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { + entity, err := r.client.User.Create(). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.ID = entity.ID + user.CreatedAt = entity.CreatedAt + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + entity, err := r.client.User.Get(ctx, id) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { + entity, err := r.client.User.UpdateOneID(user.ID). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + return r.client.User.DeleteOneID(id).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + return nil, service.ErrUserNotFound +} + +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { + return nil +} + +func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { + panic("unexpected DeductBalance call") +} + +func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) + return count > 0, err +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error { + panic("unexpected EnableTotp call") +} + +func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error { + panic("unexpected DisableTotp call") +} + +func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { + if entity == nil { + return nil + } + return &service.User{ + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index e3694c8f..ceda633c 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ) // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ProviderKey: issuer, ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } type completeOIDCOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating @@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index c389db51..9107e13a 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -12,7 +13,13 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { E: e, } } + +func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-1"). + SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Display", + "suggested_avatar_url": "https://cdn.example/oidc.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "OIDC Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "OIDC Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go new file mode 100644 index 00000000..867a77a1 --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -0,0 +1,618 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat" + wechatOAuthCookieMaxAgeSec = 10 * 60 + wechatOAuthStateCookieName = "wechat_oauth_state" + wechatOAuthRedirectCookieName = "wechat_oauth_redirect" + wechatOAuthIntentCookieName = "wechat_oauth_intent" + wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthDefaultRedirectTo = "/dashboard" + wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" + wechatOAuthProviderKey = "wechat-main" + + wechatOAuthIntentLogin = "login" + wechatOAuthIntentBind = "bind_current_user" + wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email" +) + +var ( + wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token" + wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo" +) + +type wechatOAuthConfig struct { + mode string + appID string + appSecret string + authorizeURL string + scope string + redirectURI string + frontendCallback string +} + +type wechatOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatOAuthUserInfoResponse struct { + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + HeadImgURL string `json:"headimgurl"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived +// browser cookies required by the rebuild pending-auth bridge. +func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + intent := normalizeWeChatOAuthIntent(c.Query("intent")) + secureCookie := isRequestHTTPS(c) + wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, +// and stores the result in the unified pending-auth flow. +func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { + frontendCallback := wechatOAuthFrontendCallback() + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + + intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName) + mode, err := readCookieDecoded(c, wechatOAuthModeCookieName) + if err != nil || strings.TrimSpace(mode) == "" { + redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "") + return + } + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error())) + return + } + + unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) + openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) + providerSubject := firstNonEmpty(unionid, openid) + if providerSubject == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + return + } + + username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) + email := wechatSyntheticEmail(providerSubject) + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": providerSubject, + "openid": openid, + "unionid": unionid, + "mode": cfg.mode, + "suggested_display_name": strings.TrimSpace(userInfo.Nickname), + "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + if err != nil { + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +type completeWeChatOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by +// validating the invitation code and consuming the current pending browser session. +// POST /api/v1/auth/oauth/wechat/complete-registration +func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { + var req completeWeChatOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) createWeChatPendingSession( + c *gin.Context, + intent string, + providerSubject string, + email string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + tokenPair *service.TokenPair, + authErr error, +) error { + completionResponse := map[string]any{ + "redirect": redirectTo, + } + if authErr != nil { + if errors.Is(authErr, service.ErrOAuthInvitationRequired) { + completionResponse["error"] = "invitation_required" + } else { + return authErr + } + } else if tokenPair != nil { + completionResponse["access_token"] = tokenPair.AccessToken + completionResponse["refresh_token"] = tokenPair.RefreshToken + completionResponse["expires_in"] = tokenPair.ExpiresIn + completionResponse["token_type"] = "Bearer" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: intent, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { + mode, err := resolveWeChatOAuthMode(rawMode, c) + if err != nil { + return wechatOAuthConfig{}, err + } + + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + settings, err := h.settingSvc.GetAllSettings(ctx) + if err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + + cfg := wechatOAuthConfig{ + mode: mode, + redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"), + frontendCallback: wechatOAuthFrontendCallback(), + } + + switch mode { + case "mp": + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" + cfg.scope = "snsapi_userinfo" + default: + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" + cfg.scope = "snsapi_login" + } + + if cfg.appID == "" || cfg.appSecret == "" { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if strings.TrimSpace(cfg.redirectURI) == "" { + return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + + return cfg, nil +} + +func wechatOAuthFrontendCallback() string { + return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB) +} + +func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { + mode := strings.ToLower(strings.TrimSpace(rawMode)) + if mode == "" { + if isWeChatBrowserRequest(c) { + return "mp", nil + } + return "open", nil + } + if mode != "open" && mode != "mp" { + return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp") + } + return mode, nil +} + +func isWeChatBrowserRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger") +} + +func normalizeWeChatOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "login": + return wechatOAuthIntentLogin + case "bind", "bind_current_user": + return wechatOAuthIntentBind + case "adopt", "adopt_existing_user_by_email": + return wechatOAuthIntentAdoptEmail + default: + return wechatOAuthIntentLogin + } +} + +func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) { + u, err := url.Parse(cfg.authorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize url: %w", err) + } + query := u.Query() + query.Set("appid", cfg.appID) + query.Set("redirect_uri", cfg.redirectURI) + query.Set("response_type", "code") + query.Set("scope", cfg.scope) + query.Set("state", state) + u.RawQuery = query.Encode() + u.Fragment = "wechat_redirect" + return u.String(), nil +} + +func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string { + callbackPath = strings.TrimSpace(callbackPath) + if callbackPath == "" { + return "" + } + + if raw := strings.TrimSpace(apiBaseURL); raw != "" { + if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" { + basePath := strings.TrimRight(parsed.EscapedPath(), "/") + targetPath := callbackPath + if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") { + targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1") + } else if basePath != "" { + targetPath = basePath + callbackPath + } + return parsed.Scheme + "://" + parsed.Host + targetPath + } + } + + if c == nil || c.Request == nil { + return "" + } + scheme := "http" + if isRequestHTTPS(c) { + scheme = "https" + } + host := strings.TrimSpace(c.Request.Host) + if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + if host == "" { + return "" + } + return scheme + "://" + host + callbackPath +} + +func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) { + tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code) + if err != nil { + return nil, nil, err + } + userInfo, err := fetchWeChatUserInfo(ctx, tokenResp) + if err != nil { + return nil, nil, err + } + return tokenResp, userInfo, nil +} + +func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) { + endpoint, err := url.Parse(wechatOAuthAccessTokenURL) + if err != nil { + return nil, fmt.Errorf("parse wechat access token url: %w", err) + } + + query := endpoint.Query() + query.Set("appid", cfg.appID) + query.Set("secret", cfg.appSecret) + query.Set("code", strings.TrimSpace(code)) + query.Set("grant_type", "authorization_code") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat access token request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat access token response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode) + } + + var tokenResp wechatOAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode wechat access token response: %w", err) + } + if tokenResp.ErrCode != 0 { + return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg)) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, fmt.Errorf("wechat access token missing access_token") + } + return &tokenResp, nil +} + +func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) { + if tokenResp == nil { + return nil, fmt.Errorf("wechat token response is nil") + } + + endpoint, err := url.Parse(wechatOAuthUserInfoURL) + if err != nil { + return nil, fmt.Errorf("parse wechat userinfo url: %w", err) + } + query := endpoint.Query() + query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken)) + query.Set("openid", strings.TrimSpace(tokenResp.OpenID)) + query.Set("lang", "zh_CN") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat userinfo request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat userinfo: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat userinfo response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode) + } + + var userInfo wechatOAuthUserInfoResponse + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, fmt.Errorf("decode wechat userinfo response: %w", err) + } + if userInfo.ErrCode != 0 { + return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg)) + } + return &userInfo, nil +} + +func wechatSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain +} + +func wechatFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "wechat_user" + } + return "wechat_" + truncateFragmentValue(subject) +} + +func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go new file mode 100644 index 00000000..1a765dcc --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -0,0 +1,411 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/base64" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler := &AuthHandler{} + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "appid=wx-open-app") + require.Contains(t, location, "scope=snsapi_login") + + cookies := recorder.Result().Cookies() + require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) + require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) +} + +func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat", session.ProviderType) + require.Equal(t, "wechat-main", session.ProviderKey) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) + require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) +} + +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, true) + defer client.Close() + + ctx := context.Background() + redeemRepo := repository.NewRedeemCodeRepository(client) + require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ + Code: "invite-1", + Type: service.RedeemTypeInvitation, + Status: service.StatusUnused, + })) + + callbackRecorder := httptest.NewRecorder() + callbackCtx, _ := gin.CreateTestContext(callbackRecorder) + callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + callbackReq.Host = "api.example.com" + callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + callbackCtx.Request = callbackReq + + handler.WeChatOAuthCallback(callbackCtx) + + require.Equal(t, http.StatusFound, callbackRecorder.Code) + require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) + + sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) + + pendingSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"]) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) + completeRecorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(completeRecorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, completeRecorder.Code) + responseData := decodeJSONBody(t, completeRecorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "WeChat Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + userRepo := &oauthPendingFlowUserRepo{client: client} + redeemRepo := repository.NewRedeemCodeRepository(client) + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }) + + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &wechatOAuthRefreshTokenCacheStub{}, + &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + raw, err := base64.RawURLEncoding.DecodeString(value) + require.NoError(t, err) + return string(raw) +} + +type wechatOAuthSettingRepoStub struct { + values map[string]string +} + +func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type wechatOAuthRefreshTokenCacheStub struct{} + +func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 3659e79b..f44b3e3b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -189,6 +189,7 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` SoraClientEnabled bool `json:"sora_client_enabled"` diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfeb..9fdefa93 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This allows looking up the correct provider instance before verification. func extractOutTradeNo(rawBody, providerKey string) string { switch providerKey { - case payment.TypeEasyPay: + case payment.TypeEasyPay, payment.TypeAlipay: values, err := url.ParseQuery(rawBody) if err == nil { return values.Get("out_trade_no") diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index bdef1766..6f448131 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) { assert.Equal(t, 200, webhookLogTruncateLen) }) } + +func TestExtractOutTradeNo(t *testing.T) { + tests := []struct { + name string + providerKey string + rawBody string + want string + }{ + { + name: "easypay query payload", + providerKey: "easypay", + rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS", + want: "sub2_123", + }, + { + name: "alipay query payload", + providerKey: "alipay", + rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456", + want: "sub2_456", + }, + { + name: "unknown provider", + providerKey: "wxpay", + rawBody: "{}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey)) + }) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1717b7a1..c7bc3e2a 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 2535ea5e..904341d0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -34,10 +34,16 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type userProfileResponse struct { + dto.User + AvatarURL string `json:"avatar_url,omitempty"` +} + // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { @@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, dto.UserFromService(userData)) + response.Success(c, userProfileResponseFromService(userData)) } // ChangePassword handles changing user password @@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { svcReq := service.UpdateProfileRequest{ Username: req.Username, + AvatarURL: req.AvatarURL, BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyThreshold: req.BalanceNotifyThreshold, } @@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) +} + +func userProfileResponseFromService(user *service.User) userProfileResponse { + base := dto.UserFromService(user) + if base == nil { + return userProfileResponse{} + } + return userProfileResponse{ + User: *base, + AvatarURL: user.AvatarURL, + } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go new file mode 100644 index 00000000..1973f59e --- /dev/null +++ b/backend/internal/handler/user_handler_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userHandlerRepoStub struct { + user *service.User +} + +func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } +func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error { + cloned := *user + s.user = &cloned + return nil +} +func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + if s.user == nil || s.user.AvatarURL == "" { + return nil, nil + } + return &service.UserAvatar{ + StorageProvider: s.user.AvatarSource, + URL: s.user.AvatarURL, + ContentType: s.user.AvatarMIME, + ByteSize: s.user.AvatarByteSize, + SHA256: s.user.AvatarSHA256, + }, nil +} +func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + s.user.AvatarURL = input.URL + s.user.AvatarSource = input.StorageProvider + s.user.AvatarMIME = input.ContentType + s.user.AvatarByteSize = input.ByteSize + s.user.AvatarSHA256 = input.SHA256 + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error { + s.user.AvatarURL = "" + s.user.AvatarSource = "" + s.user.AvatarMIME = "" + s.user.AvatarByteSize = 0 + s.user.AvatarSHA256 = "" + return nil +} +func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "handler-avatar@example.com", + Username: "handler-avatar", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.UpdateProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + AvatarURL string `json:"avatar_url"` + Username string `json:"username"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) + require.Equal(t, "handler-avatar", resp.Data.Username) +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 38ea9bde..36d80309 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyExtraEmails, user.FieldTotalRecharged, + user.FieldSignupSource, + user.FieldLastLoginAt, + user.FieldLastActiveAt, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User { Balance: u.Balance, Concurrency: u.Concurrency, Status: u.Status, + SignupSource: u.SignupSource, + LastLoginAt: u.LastLoginAt, + LastActiveAt: u.LastActiveAt, TotpSecretEncrypted: u.TotpSecretEncrypted, TotpEnabled: u.TotpEnabled, TotpEnabledAt: u.TotpEnabledAt, diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go new file mode 100644 index 00000000..70f298c1 --- /dev/null +++ b/backend/internal/repository/auth_identity_migration_report.go @@ -0,0 +1,148 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type AuthIdentityMigrationReport struct { + ID int64 + ReportType string + ReportKey string + Details map[string]any + CreatedAt time.Time +} + +type AuthIdentityMigrationReportQuery struct { + ReportType string + Limit int + Offset int +} + +type AuthIdentityMigrationReportSummary struct { + Total int64 + ByType map[string]int64 +} + +func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + limit := query.Limit + if limit <= 0 { + limit = 100 + } + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1) +ORDER BY created_at DESC, id DESC +LIMIT $2 OFFSET $3`, + strings.TrimSpace(query.ReportType), + limit, + query.Offset, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + reports := make([]AuthIdentityMigrationReport, 0) + for rows.Next() { + report, scanErr := scanAuthIdentityMigrationReport(rows) + if scanErr != nil { + return nil, scanErr + } + reports = append(reports, report) + } + if err := rows.Err(); err != nil { + return nil, err + } + return reports, nil +} + +func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE report_type = $1 AND report_key = $2 +LIMIT 1`, + strings.TrimSpace(reportType), + strings.TrimSpace(reportKey), + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, sql.ErrNoRows + } + report, err := scanAuthIdentityMigrationReport(rows) + if err != nil { + return nil, err + } + return &report, rows.Err() +} + +func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT report_type, COUNT(*) +FROM auth_identity_migration_reports +GROUP BY report_type +ORDER BY report_type ASC`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + summary := &AuthIdentityMigrationReportSummary{ + ByType: make(map[string]int64), + } + for rows.Next() { + var reportType string + var count int64 + if err := rows.Scan(&reportType, &count); err != nil { + return nil, err + } + summary.ByType[reportType] = count + summary.Total += count + } + if err := rows.Err(); err != nil { + return nil, err + } + return summary, nil +} + +func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { + var ( + report AuthIdentityMigrationReport + details []byte + ) + if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + return AuthIdentityMigrationReport{}, err + } + report.Details = map[string]any{} + if len(details) > 0 { + if err := json.Unmarshal(details, &report.Details); err != nil { + return AuthIdentityMigrationReport{}, err + } + } + return report, nil +} diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go new file mode 100644 index 00000000..4ecae4a4 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -0,0 +1,544 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "time" + "unsafe" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +var ( + ErrAuthIdentityOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_OWNERSHIP_CONFLICT", + "auth identity already belongs to another user", + ) + ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", + "auth identity channel already belongs to another user", + ) +) + +type ProviderGrantReason string + +const ( + ProviderGrantReasonSignup ProviderGrantReason = "signup" + ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind" +) + +type AuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type AuthIdentityChannelKey struct { + ProviderType string + ProviderKey string + Channel string + ChannelAppID string + ChannelSubject string +} + +type CreateAuthIdentityInput struct { + UserID int64 + Canonical AuthIdentityKey + Channel *AuthIdentityChannelKey + Issuer *string + VerifiedAt *time.Time + Metadata map[string]any + ChannelMetadata map[string]any +} + +type BindAuthIdentityInput = CreateAuthIdentityInput + +type CreateAuthIdentityResult struct { + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey { + if r == nil || r.Identity == nil { + return AuthIdentityKey{} + } + return AuthIdentityKey{ + ProviderType: r.Identity.ProviderType, + ProviderKey: r.Identity.ProviderKey, + ProviderSubject: r.Identity.ProviderSubject, + } +} + +func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey { + if r == nil || r.Channel == nil { + return nil + } + return &AuthIdentityChannelKey{ + ProviderType: r.Channel.ProviderType, + ProviderKey: r.Channel.ProviderKey, + Channel: r.Channel.Channel, + ChannelAppID: r.Channel.ChannelAppID, + ChannelSubject: r.Channel.ChannelSubject, + } +} + +type UserAuthIdentityLookup struct { + User *dbent.User + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +type ProviderGrantRecordInput struct { + UserID int64 + ProviderType string + GrantReason ProviderGrantReason +} + +type IdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type sqlQueryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if dbent.TxFromContext(ctx) != nil { + return fn(ctx) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + return err + } + return tx.Commit() +} + +func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + client := clientFromContext(ctx, r.client) + + create := client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt) + + identity, err := create.Save(ctx) + if err != nil { + return nil, err + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(ctx) + if err != nil { + return nil, err + } + } + + return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil +} + +func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) { + identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)), + ). + WithUser(). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: identity.Edges.User, + Identity: identity, + }, nil +} + +func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) { + channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: channel.Edges.Identity.Edges.User, + Identity: channel.Edges.Identity, + Channel: channel, + }, nil +} + +func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + var result *CreateAuthIdentityResult + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + canonical := input.Canonical + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), + ). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if identity != nil && identity.UserID != input.UserID { + return ErrAuthIdentityOwnershipConflict + } + if identity == nil { + identity, err = client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentity.UpdateOneID(identity.ID) + if input.Metadata != nil { + update = update.SetMetadata(copyMetadata(input.Metadata)) + } + if input.Issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*input.Issuer)) + } + if input.VerifiedAt != nil { + update = update.SetVerifiedAt(*input.VerifiedAt) + } + identity, err = update.Save(txCtx) + if err != nil { + return err + } + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), + ). + WithIdentity(). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID { + return ErrAuthIdentityChannelOwnershipConflict + } + if channel == nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID) + if input.ChannelMetadata != nil { + update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) + } + channel, err = update.Save(txCtx) + if err != nil { + return err + } + } + } + + result = &CreateAuthIdentityResult{Identity: identity, Channel: channel} + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return false, fmt.Errorf("sql executor is not configured") + } + + result, err := exec.ExecContext(ctx, ` +INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + input.UserID, + strings.TrimSpace(input.ProviderType), + string(input.GrantReason), + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + client := clientFromContext(ctx, r.client) + current, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + now := time.Now().UTC() + if current == nil { + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(now) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { + return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)). + Only(ctx) +} + +func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastLoginAt(loginAt). + Save(ctx) + return err +} + +func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastActiveAt(activeAt). + Save(ctx) + return err +} + +func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + rows, err := exec.QueryContext(ctx, ` +SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 +FROM user_avatars +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + _, err = exec.ExecContext(ctx, ` +INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) +ON CONFLICT (user_id) DO UPDATE SET + storage_provider = EXCLUDED.storage_provider, + storage_key = EXCLUDED.storage_key, + url = EXCLUDED.url, + content_type = EXCLUDED.content_type, + byte_size = EXCLUDED.byte_size, + sha256 = EXCLUDED.sha256, + updated_at = NOW()`, + userID, + strings.TrimSpace(input.StorageProvider), + strings.TrimSpace(input.StorageKey), + strings.TrimSpace(input.URL), + strings.TrimSpace(input.ContentType), + input.ByteSize, + strings.TrimSpace(input.SHA256), + ) + if err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: strings.TrimSpace(input.StorageProvider), + StorageKey: strings.TrimSpace(input.StorageKey), + URL: strings.TrimSpace(input.URL), + ContentType: strings.TrimSpace(input.ContentType), + ByteSize: input.ByteSize, + SHA256: strings.TrimSpace(input.SHA256), + }, nil +} + +func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return err + } + _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID) + return err +} + +func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error { + if user == nil { + return nil + } + + avatar, err := r.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + if avatar == nil { + return nil + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 + return nil +} + +func copyMetadata(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { + return exec + } + } + if fallback != nil { + return fallback + } + return sqlExecutorFromEntClient(client) +} + +func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + return exec, nil +} + +func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor { + if client == nil { + return nil + } + + clientValue := reflect.ValueOf(client).Elem() + configValue := clientValue.FieldByName("config") + driverValue := configValue.FieldByName("driver") + if !driverValue.IsValid() { + return nil + } + + driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface() + exec, ok := driver.(sqlQueryExecutor) + if !ok { + return nil + } + return exec +} diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go new file mode 100644 index 00000000..19022ec1 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -0,0 +1,428 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" +) + +type UserProfileIdentityRepoSuite struct { + suite.Suite + ctx context.Context + client *dbent.Client + repo *userRepository +} + +func TestUserProfileIdentityRepoSuite(t *testing.T) { + suite.Run(t, new(UserProfileIdentityRepoSuite)) +} + +func (s *UserProfileIdentityRepoSuite) SetupTest() { + s.ctx = context.Background() + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + _, err := integrationDB.ExecContext(s.ctx, ` +TRUNCATE TABLE + identity_adoption_decisions, + auth_identity_channels, + auth_identities, + pending_auth_sessions, + auth_identity_migration_reports, + user_provider_default_grants, + user_avatars +RESTART IDENTITY`) + s.Require().NoError(err) +} + +func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User { + s.T().Helper() + + user, err := s.client.User.Create(). + SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())). + SetPasswordHash("test-password-hash"). + SetRole("user"). + SetStatus("active"). + Save(s.ctx) + s.Require().NoError(err) + return user +} + +func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession { + s.T().Helper() + + session, err := s.client.PendingAuthSession.Create(). + SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())). + SetIntent("bind_current_user"). + SetProviderType(key.ProviderType). + SetProviderKey(key.ProviderKey). + SetProviderSubject(key.ProviderSubject). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(s.ctx) + s.Require().NoError(err) + return session +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() { + user := s.mustCreateUser("canonical-channel") + + verifiedAt := time.Now().UTC().Truncate(time.Second) + created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + Channel: "mp", + ChannelAppID: "wx-app", + ChannelSubject: "openid-123", + }, + Issuer: stringPtr("https://issuer.example"), + VerifiedAt: &verifiedAt, + Metadata: map[string]any{"unionid": "union-123"}, + ChannelMetadata: map[string]any{"openid": "openid-123"}, + }) + s.Require().NoError(err) + s.Require().NotNil(created.Identity) + s.Require().NotNil(created.Channel) + + canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, canonical.User.ID) + s.Require().Equal(created.Identity.ID, canonical.Identity.ID) + s.Require().Equal("union-123", canonical.Identity.ProviderSubject) + + channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, channel.User.ID) + s.Require().Equal(created.Identity.ID, channel.Identity.ID) + s.Require().Equal(created.Channel.ID, channel.Channel.ID) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() { + owner := s.mustCreateUser("owner") + other := s.mustCreateUser("other") + + first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "first"}, + ChannelMetadata: map[string]any{"scope": "read"}, + }) + s.Require().NoError(err) + + second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "second"}, + ChannelMetadata: map[string]any{"scope": "write"}, + }) + s.Require().NoError(err) + s.Require().Equal(first.Identity.ID, second.Identity.ID) + s.Require().Equal(first.Channel.ID, second.Channel.ID) + s.Require().Equal("second", second.Identity.Metadata["username"]) + s.Require().Equal("write", second.Channel.Metadata["scope"]) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-2", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { + user := s.mustCreateUser("tx-rollback") + expectedErr := errors.New("rollback") + + err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }, + }) + s.Require().NoError(err) + + inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "oidc", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + return expectedErr + }) + s.Require().ErrorIs(err, expectedErr) + + _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }) + s.Require().True(dbent.IsNotFound(err)) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`, + user.ID, + "oidc", + string(ProviderGrantReasonFirstBind), + ).Scan(&count)) + s.Require().Zero(count) +} + +func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() { + user := s.mustCreateUser("grant") + + inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().False(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonSignup, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2`, + user.ID, + "wechat", + ).Scan(&count)) + s.Require().Equal(2, count) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() { + user := s.mustCreateUser("adoption") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + s.Require().NoError(err) + + session := s.mustCreatePendingAuthSession(identity.IdentityRef()) + + first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().True(first.AdoptDisplayName) + s.Require().False(first.AdoptAvatar) + s.Require().Nil(first.IdentityID) + + second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().Equal(first.ID, second.ID) + s.Require().NotNil(second.IdentityID) + s.Require().Equal(identity.Identity.ID, *second.IdentityID) + s.Require().True(second.AdoptAvatar) + + loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID) + s.Require().NoError(err) + s.Require().Equal(second.ID, loaded.ID) + s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { + user := s.mustCreateUser("avatar") + + inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: "data:image/png;base64,QUJD", + ContentType: "image/png", + ByteSize: 3, + SHA256: "902fbdd2b1df0c4f70b4a5d23525e932", + }) + s.Require().NoError(err) + s.Require().Equal("inline", inlineAvatar.StorageProvider) + s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL) + + loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("image/png", loadedAvatar.ContentType) + s.Require().Equal(3, loadedAvatar.ByteSize) + + _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + s.Require().NoError(err) + + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("remote_url", loadedAvatar.StorageProvider) + s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL) + s.Require().Zero(loadedAvatar.ByteSize) + + s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID)) + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Nil(loadedAvatar) +} + +func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() { + _, err := integrationDB.ExecContext(s.ctx, ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'), + ('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'), + ('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`) + s.Require().NoError(err) + + summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx) + s.Require().NoError(err) + s.Require().Equal(int64(3), summary.Total) + s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"]) + s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) + + reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{ + ReportType: "wechat_openid_only_requires_remediation", + Limit: 10, + }) + s.Require().NoError(err) + s.Require().Len(reports, 2) + s.Require().Equal("u-2", reports[0].ReportKey) + s.Require().Equal(float64(2), reports[0].Details["user_id"]) + + report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3") + s.Require().NoError(err) + s.Require().Equal("u-3", report.ReportKey) + s.Require().Equal(float64(3), report.Details["user_id"]) +} + +func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() { + user := s.mustCreateUser("activity") + loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + activeAt := loginAt.Add(5 * time.Minute) + + s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt)) + s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt)) + + var storedLoginAt sqlNullTime + var storedActiveAt sqlNullTime + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT last_login_at, last_active_at +FROM users +WHERE id = $1`, + user.ID, + ).Scan(&storedLoginAt, &storedActiveAt)) + s.Require().True(storedLoginAt.Valid) + s.Require().True(storedActiveAt.Valid) + s.Require().True(storedLoginAt.Time.Equal(loginAt)) + s.Require().True(storedActiveAt.Time.Equal(activeAt)) +} + +type sqlNullTime struct { + Time time.Time + Valid bool +} + +func (t *sqlNullTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + t.Time = v + t.Valid = true + return nil + case nil: + t.Time = time.Time{} + t.Valid = false + return nil + default: + return fmt.Errorf("unsupported scan type %T", value) + } +} + +func stringPtr(v string) *string { + return &v +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 913e1c40..0c607ecc 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). + SetNillableLastLoginAt(userIn.LastLoginAt). + SetNillableLastActiveAt(userIn.LastActiveAt). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetTotalRecharged(userIn.TotalRecharged) + if userIn.SignupSource != "" { + updateOp = updateOp.SetSignupSource(userIn.SignupSource) + } + if userIn.LastLoginAt != nil { + updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt) + } + if userIn.LastActiveAt != nil { + updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt) + } if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } @@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) var field string defaultField := true + nullsLastField := false switch sortBy { case "email": field = dbuser.FieldEmail @@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) case "created_at": field = dbuser.FieldCreatedAt defaultField = false + case "last_login_at": + field = dbuser.FieldLastLoginAt + defaultField = false + nullsLastField = true + case "last_active_at": + field = dbuser.FieldLastActiveAt + defaultField = false + nullsLastField = true default: field = dbuser.FieldID } @@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(), + dbent.Asc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} } if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(), + dbent.Desc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} } @@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { return } dst.ID = src.ID + dst.SignupSource = src.SignupSource + dst.LastLoginAt = src.LastLoginAt + dst.LastActiveAt = src.LastActiveAt dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } +func userSignupSourceOrDefault(signupSource string) string { + signupSource = strings.TrimSpace(signupSource) + if signupSource == "" { + return "email" + } + return signupSource +} + // marshalExtraEmails serializes notify email entries to JSON for storage. func marshalExtraEmails(entries []service.NotifyEmailEntry) string { return service.MarshalNotifyEmails(entries) diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go index ab84b0e9..8abef45a 100644 --- a/backend/internal/repository/user_repo_sort_integration_test.go +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { s.Require().Equal(first.ID, users[1].ID) } +func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() { + lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond) + + created := s.mustCreateUser(&service.User{ + Email: "identity-meta@example.com", + SignupSource: "github", + LastLoginAt: &lastLoginAt, + LastActiveAt: &lastActiveAt, + }) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("github", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() { + created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"}) + lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond) + + created.SignupSource = "oidc" + created.LastLoginAt = &lastLoginAt + created.LastActiveAt = &lastActiveAt + + s.Require().NoError(s.repo.Update(s.ctx, created)) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("oidc", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() { + older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond) + newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-login@example.com"}) + s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older}) + s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_login_at", + SortOrder: "desc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("newer-login@example.com", users[0].Email) + s.Require().Equal("older-login@example.com", users[1].Email) + s.Require().Equal("nil-login@example.com", users[2].Email) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() { + earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond) + later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-active@example.com"}) + s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later}) + s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_active_at", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("earlier-active@example.com", users[0].Email) + s.Require().Equal("later-active@example.com", users[1].Email) + s.Require().Equal("nil-active@example.com", users[2].Email) +} + func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b686b986..e903898f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectUsePKCE: "true", service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectClockSkewSeconds: "120", @@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_redirect_url": "", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_token_auth_method": "client_secret_post", - "oidc_connect_use_pkce": false, + "oidc_connect_use_pkce": true, "oidc_connect_validate_id_token": true, "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_clock_skew_seconds": 120, diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c143b030..911a4064 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -64,12 +64,26 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) + auth.POST("/oauth/pending/exchange", + rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.ExchangePendingOAuthCompletion, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/wechat/complete-registration", + rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteWeChatOAuthRegistration, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 419ddbc3..b802a9c2 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fbc856cf..323286b0 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { return s.deleteErr } +func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected List call") } diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go new file mode 100644 index 00000000..b7e86e12 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service.go @@ -0,0 +1,326 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +var ( + ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") + ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") + ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") + ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") + ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") + ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") + ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") +) + +const ( + defaultPendingAuthTTL = 15 * time.Minute + defaultPendingAuthCompletionTTL = 5 * time.Minute +) + +type PendingAuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type CreatePendingAuthSessionInput struct { + SessionToken string + Intent string + Identity PendingAuthIdentityKey + TargetUserID *int64 + RedirectTo string + ResolvedEmail string + RegistrationPasswordHash string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + LocalFlowState map[string]any + ExpiresAt time.Time +} + +type IssuePendingAuthCompletionCodeInput struct { + PendingAuthSessionID int64 + BrowserSessionKey string + TTL time.Duration +} + +type IssuePendingAuthCompletionCodeResult struct { + Code string + ExpiresAt time.Time +} + +type PendingIdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type AuthPendingIdentityService struct { + entClient *dbent.Client +} + +func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { + return &AuthPendingIdentityService{entClient: entClient} +} + +func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken := strings.TrimSpace(input.SessionToken) + if sessionToken == "" { + var err error + sessionToken, err = randomOpaqueToken(24) + if err != nil { + return nil, err + } + } + + expiresAt := input.ExpiresAt.UTC() + if expiresAt.IsZero() { + expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) + } + + create := s.entClient.PendingAuthSession.Create(). + SetSessionToken(sessionToken). + SetIntent(strings.TrimSpace(input.Intent)). + SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). + SetRedirectTo(strings.TrimSpace(input.RedirectTo)). + SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). + SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). + SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). + SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). + SetLocalFlowState(copyPendingMap(input.LocalFlowState)). + SetExpiresAt(expiresAt) + if input.TargetUserID != nil { + create = create.SetTargetUserID(*input.TargetUserID) + } + return create.Save(ctx) +} + +func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + + code, err := randomOpaqueToken(24) + if err != nil { + return nil, err + } + ttl := input.TTL + if ttl <= 0 { + ttl = defaultPendingAuthCompletionTTL + } + expiresAt := time.Now().UTC().Add(ttl) + + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeHash(hashPendingAuthCode(code)). + SetCompletionCodeExpiresAt(expiresAt) + if strings.TrimSpace(input.BrowserSessionKey) != "" { + update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) + } + if _, err := update.Save(ctx); err != nil { + return nil, err + } + + return &IssuePendingAuthCompletionCodeResult{ + Code: code, + ExpiresAt: expiresAt, + }, nil +} + +func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthCodeInvalid + } + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) +} + +func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) +} + +func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken = strings.TrimSpace(sessionToken) + if sessionToken == "" { + return nil, ErrPendingAuthSessionNotFound + } + + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) consumeSession( + ctx context.Context, + session *dbent.PendingAuthSession, + browserSessionKey string, + expiredErr error, + consumedErr error, +) (*dbent.PendingAuthSession, error) { + if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + + now := time.Now().UTC() + updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx) + if err != nil { + return nil, err + } + return updated, nil +} + +func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { + if session == nil { + return ErrPendingAuthSessionNotFound + } + + now := time.Now().UTC() + if session.ConsumedAt != nil { + return consumedErr + } + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + return expiredErr + } + if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { + return expiredErr + } + if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return ErrPendingAuthBrowserMismatch + } + return nil +} + +func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + existing, err := s.entClient.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if existing == nil { + create := s.entClient.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func copyPendingMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func randomOpaqueToken(byteLen int) (string, error) { + if byteLen <= 0 { + byteLen = 16 + } + buf := make([]byte, byteLen) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func hashPendingAuthCode(code string) string { + sum := sha256.Sum256([]byte(code)) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go new file mode 100644 index 00000000..c69ebfd2 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -0,0 +1,224 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return NewAuthPendingIdentityService(client), client +} + +func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("pending-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + TargetUserID: &targetUser.ID, + RedirectTo: "/profile", + ResolvedEmail: "user@example.com", + BrowserSessionKey: "browser-1", + UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"}, + LocalFlowState: map[string]any{"step": "email_required"}, + }) + require.NoError(t, err) + require.NotEmpty(t, session.SessionToken) + require.Equal(t, "bind_current_user", session.Intent) + require.Equal(t, "wechat", session.ProviderType) + require.NotNil(t, session.TargetUserID) + require.Equal(t, targetUser.ID, *session.TargetUserID) + require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"]) + require.Equal(t, "email_required", session.LocalFlowState["step"]) +} + +func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expected", + UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"}, + LocalFlowState: map[string]any{"step": "pending"}, + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expected", + }) + require.NoError(t, err) + require.NotEmpty(t, issued.Code) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + require.Empty(t, consumed.CompletionCodeHash) + require.Nil(t, consumed.CompletionCodeExpiresAt) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.ErrorIs(t, err, ErrPendingAuthCodeInvalid) +} + +func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expired", + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expired", + TTL: time.Second, + }) + require.NoError(t, err) + + _, err = client.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired") + require.ErrorIs(t, err, ErrPendingAuthCodeExpired) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + require.NoError(t, err) + + first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.True(t, first.AdoptDisplayName) + require.False(t, first.AdoptAvatar) + require.Nil(t, first.IdentityID) + + second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + require.NotNil(t, second.IdentityID) + require.Equal(t, identity.ID, *second.IdentityID) + require.True(t, second.AdoptAvatar) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "subject-session-token", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "token", + }, + }, + }) + require.NoError(t, err) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fd28cd42..962009ce 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -13,6 +13,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -106,6 +107,13 @@ func NewAuthService( } } +func (s *AuthService) EntClient() *dbent.Client { + if s == nil { + return nil + } + return s.entClient +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") @@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.postAuthUserBootstrap(ctx, user, "email", true) s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) @@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } + s.touchUserLogin(ctx, user.ID) // 生成JWT token token, err := s.GenerateToken(user) @@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) token, err := s.GenerateToken(user) if err != nil { @@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { @@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. -const pendingOAuthTokenTTL = 10 * time.Minute - -// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. -const pendingOAuthPurpose = "pending_oauth_registration" - -type pendingOAuthClaims struct { - Email string `json:"email"` - Username string `json:"username"` - Purpose string `json:"purpose"` - jwt.RegisteredClaims -} - -// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity -// while waiting for the user to supply an invitation code. -func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { - now := time.Now() - claims := &pendingOAuthClaims{ - Email: email, - Username: username, - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.Secret)) -} - -// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. -// Returns ErrInvalidToken when the token is invalid or expired. -func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { - if len(tokenStr) > maxTokenLength { - return "", "", ErrInvalidToken - } - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return []byte(s.cfg.JWT.Secret), nil - }) - if parseErr != nil { - return "", "", ErrInvalidToken - } - claims, ok := token.Claims.(*pendingOAuthClaims) - if !ok || !token.Valid { - return "", "", ErrInvalidToken - } - if claims.Purpose != pendingOAuthPurpose { - return "", "", ErrInvalidToken - } - return claims.Email, claims.Username, nil -} - func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int } } +func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { + if user == nil || user.ID <= 0 { + return + } + + if strings.TrimSpace(signupSource) == "" { + signupSource = "email" + } + s.updateUserSignupSource(ctx, user.ID, signupSource) + + if signupSource == "email" { + s.ensureEmailAuthIdentity(ctx, user) + } + if touchLogin { + s.touchUserLogin(ctx, user.ID) + } +} + +func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + if strings.TrimSpace(signupSource) == "" { + return + } + if err := s.entClient.User.UpdateOneID(userID). + SetSignupSource(signupSource). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err) + } +} + +func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + now := time.Now().UTC() + if err := s.entClient.User.UpdateOneID(userID). + SetLastLoginAt(now). + SetLastActiveAt(now). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err) + } +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return + } + + if err := s.entClient.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": "auth_service_dual_write", + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + } +} + +func inferLegacySignupSource(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + switch { + case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): + return "linuxdo" + case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): + return "oidc" + case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain): + return "wechat" + default: + return "email" + } +} + func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { if s.settingService == nil { return nil @@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go new file mode 100644 index 00000000..5bd2b25d --- /dev/null +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -0,0 +1,153 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type authIdentitySettingRepoStub struct { + values map[string]string +} + +func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-auth-identity-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, + }, cfg) + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil) + return svc, repo, client +} + +func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + token, user, err := svc.Register(ctx, "user@example.com", "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, user) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "email", storedUser.SignupSource) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("user@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.NotNil(t, identity.VerifiedAt) +} + +func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + user := &service.User{ + Email: "login@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) + _, err := client.User.UpdateOneID(user.ID). + SetLastLoginAt(old). + SetLastActiveAt(old). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + require.True(t, storedUser.LastLoginAt.After(old)) + require.True(t, storedUser.LastActiveAt.After(old)) +} diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go deleted file mode 100644 index 0472e06c..00000000 --- a/backend/internal/service/auth_service_pending_oauth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build unit - -package service - -import ( - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" -) - -func newAuthServiceForPendingOAuthTest() *AuthService { - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret-pending-oauth", - ExpireHour: 1, - }, - } - return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) -} - -// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 -func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - require.NotEmpty(t, token) - - email, username, err := svc.VerifyPendingOAuthToken(token) - require.NoError(t, err) - require.Equal(t, "user@example.com", email) - require.Equal(t, "alice", username) -} - -// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - // 签发一个普通 access token(JWTClaims,无 Purpose 字段) - accessToken, err := svc.GenerateToken(&User{ - ID: 1, - Email: "user@example.com", - Role: RoleUser, - }) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(accessToken) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "some_other_purpose", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "", // 旧 token 无此字段,反序列化后为零值 - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - past := time.Now().Add(-1 * time.Hour) - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(past), - IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { - other := NewAuthService(nil, nil, nil, nil, &config.Config{ - JWT: config.JWTConfig{Secret: "other-secret"}, - }, nil, nil, nil, nil, nil, nil) - - token, err := other.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - - svc := newAuthServiceForPendingOAuthTest() - _, _, err = svc.VerifyPendingOAuthToken(token) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - giant := make([]byte, maxTokenLength+1) - for i := range giant { - giant[i] = 'a' - } - _, _, err := svc.VerifyPendingOAuthToken(string(giant)) - require.ErrorIs(t, err, ErrInvalidToken) -} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cb452efb..1dddf77e 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" +// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" + // Setting keys const ( // 注册设置 @@ -153,6 +156,29 @@ const ( SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + // 第三方认证来源默认授予配置 + SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" + // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354..09e60220 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -13,14 +13,30 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/sync/singleflight" ) const ( openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerLoadBalance = "load_balance" + openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled" ) +const ( + openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second + openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second +) + +type cachedOpenAIAdvancedSchedulerSetting struct { + enabled bool + expiresAt int64 +} + +var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting +var openAIAdvancedSchedulerSettingSF singleflight.Group + type OpenAIAccountScheduleRequest struct { GroupID *int64 SessionHash string @@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler return snapshot } -func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { +func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository { + if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil { + return nil + } + return s.rateLimitService.settingService.settingRepo +} + +func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled + } + } + + result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled, nil + } + } + + enabled := false + if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil { + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout) + defer cancel() + + value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey) + if err == nil { + enabled = strings.EqualFold(strings.TrimSpace(value), "true") + } + } + + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: enabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + return enabled, nil + }) + + enabled, _ := result.(bool) + return enabled +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler { if s == nil { return nil } + if !s.isOpenAIAdvancedSchedulerEnabled(ctx) { + return nil + } s.openaiSchedulerOnce.Do(func() { if s.openaiAccountStats == nil { s.openaiAccountStats = newOpenAIAccountRuntimeStats() @@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule return s.openaiScheduler } +func resetOpenAIAdvancedSchedulerSettingCacheForTest() { + openAIAdvancedSchedulerSettingCache = atomic.Value{} + openAIAdvancedSchedulerSettingSF = singleflight.Group{} +} + func (s *OpenAIGatewayService) SelectAccountWithScheduler( ctx context.Context, groupID *int64, @@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance @@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( } func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { } func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return OpenAIAccountSchedulerMetricsSnapshot{} } diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed..a54f2614 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "math" "sync" @@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { accountsByID map[int64]*Account } +type schedulerTestOpenAIAccountRepo struct { + AccountRepository + accounts []Account +} + +func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + +type schedulerTestConcurrencyCache struct { + ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool +} + +func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } + return true, nil +} + +func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} + +func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } + out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } + for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } + out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return out, nil +} + +func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type schedulerTestGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + +func newSchedulerTestOpenAIWSV2Config() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} + +type openAIAdvancedSchedulerSettingRepoStub struct { + values map[string]string +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s == nil || s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected call to Set") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected call to GetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected call to SetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected call to GetAll") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected call to Delete") +} + +func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + repo := &openAIAdvancedSchedulerSettingRepoStub{ + values: map[string]string{}, + } + if enabled != "" { + repo.values[openAIAdvancedSchedulerSettingKey] = enabled + } + return &RateLimitService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { if len(s.snapshotAccounts) == 0 { return nil, false, nil @@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 return &cloned, nil } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10106) + accounts := []Account{ + { + ID: 36001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + }, + { + ID: 36002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour)) + require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_disabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10107) + accounts := []Account{ + { + ID: 37001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour)) + require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_enabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { ctx := context.Background() groupID := int64(10101) @@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, + cache: cache, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) @@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + } account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) require.NoError(t, err) @@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} snapshotCache := &openAISnapshotCacheStub{ snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, cache: cache, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, } @@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "openai_apikey_responses_websockets_v2_enabled": true, }, } - cache := &stubGatewayCache{} + cache := &schedulerTestGatewayCache{} cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true @@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: cfg, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } store := svc.getOpenAIWSStateStore() @@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_abc": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS Priority: 9, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_sticky_busy": 21001, }, @@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ acquireResults: map[int64]bool{ 21001: false, // sticky 账号已满 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) @@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "openai_ws_force_http": true, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_force_http": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick }, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_ws_only": 2201, }, } - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, @@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, - cfg: newOpenAIWSV2TestConfig(), - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: newSchedulerTestOpenAIWSV2Config(), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, @@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_metrics": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, @@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { } func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + svc := &OpenAIGatewayService{} ttft := 120 svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() - require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) @@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() scheduler.service = &OpenAIGatewayService{cfg: cfg} account := &Account{ ID: 8801, diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index c5de8203..ddafc6eb 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 59764b29..34462a3a 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingHelpImageURL, SettingHelpText, SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, + SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource, + SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { return nil, fmt.Errorf("get payment config settings: %w", err) } cfg := s.parsePaymentConfig(vals) + if s.entClient != nil { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("list enabled provider instances: %w", err) + } + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances)) + } else { + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil) + } // Load Stripe publishable key from the first enabled Stripe provider instance cfg.StripePublishableKey = s.getStripePublishableKey(ctx) return cfg, nil @@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { + types := make([]string, 0, len(strings.Split(raw, ","))) for _, t := range strings.Split(raw, ",") { t = strings.TrimSpace(t) if t != "" { - cfg.EnabledTypes = append(cfg.EnabledTypes, t) + types = append(types, t) } } + cfg.EnabledTypes = NormalizeVisibleMethods(types) } return cfg } // getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + if s.entClient == nil { + return "" + } instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.EnabledEQ(true), @@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int { } return v } + +func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool { + available := make(map[string]bool, 4) + for _, inst := range instances { + switch inst.ProviderKey { + case payment.TypeAlipay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) { + available[VisibleMethodSourceOfficialAlipay] = true + } + case payment.TypeWxpay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) { + available[VisibleMethodSourceOfficialWechat] = true + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(inst.SupportedTypes) { + switch NormalizeVisibleMethod(supportedType) { + case payment.TypeAlipay: + available[VisibleMethodSourceEasyPayAlipay] = true + case payment.TypeWxpay: + available[VisibleMethodSourceEasyPayWechat] = true + } + } + } + } + return available +} + +func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string { + shouldExpose := map[string]bool{ + payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available), + payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available), + } + + seen := make(map[string]struct{}, len(base)+2) + out := make([]string, 0, len(base)+2) + appendType := func(paymentType string) { + paymentType = NormalizeVisibleMethod(paymentType) + if paymentType == "" { + return + } + if _, ok := seen[paymentType]; ok { + return + } + seen[paymentType] = struct{}{} + out = append(out, paymentType) + } + + for _, paymentType := range base { + visibleMethod := NormalizeVisibleMethod(paymentType) + switch visibleMethod { + case payment.TypeAlipay, payment.TypeWxpay: + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + default: + appendType(visibleMethod) + } + } + + for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + } + return out +} + +func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool { + enabledKey := visibleMethodEnabledSettingKey(method) + sourceKey := visibleMethodSourceSettingKey(method) + if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" { + return false + } + source := NormalizeVisibleMethodSource(method, vals[sourceKey]) + return source != "" && available[source] +} diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 027bb796..10919058 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -1,9 +1,17 @@ package service import ( + "context" + "database/sql" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestPcParseFloat(t *testing.T) { @@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) { } }) + t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) { + t.Parallel() + vals := map[string]string{ + SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay", + } + cfg := svc.parsePaymentConfig(vals) + if len(cfg.EnabledTypes) != 2 { + t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes)) + } + if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" { + t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes) + } + }) + t.Run("empty enabled types string", func(t *testing.T) { t.Parallel() vals := map[string]string{ @@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) { }) } } + +func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) { + t.Parallel() + + base := []string{"alipay", "wxpay", "stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + } + available := map[string]bool{ + VisibleMethodSourceOfficialAlipay: true, + VisibleMethodSourceOfficialWechat: false, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"alipay", "stripe"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) { + t.Parallel() + + base := []string{"stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay, + } + available := map[string]bool{ + VisibleMethodSourceEasyPayAlipay: true, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"stripe", "alipay"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestBuildVisibleMethodSourceAvailability(t *testing.T) { + t.Parallel() + + instances := []*dbent.PaymentProviderInstance{ + {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"}, + {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"}, + {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"}, + } + + got := buildVisibleMethodSourceAvailability(instances) + if !got[VisibleMethodSourceOfficialAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay) + } + if !got[VisibleMethodSourceEasyPayAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay) + } + if !got[VisibleMethodSourceOfficialWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat) + } + if !got[VisibleMethodSourceEasyPayWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat) + } +} + +func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingEnabledPaymentTypes: "alipay,wxpay,stripe", + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: "easypay", + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: "wxpay", + }, + }, + } + + cfg, err := svc.GetPaymentConfig(ctx) + if err != nil { + t.Fatalf("GetPaymentConfig returned error: %v", err) + } + + want := []string{payment.TypeAlipay, payment.TypeStripe} + if len(cfg.EnabledTypes) != len(want) { + t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes) + } + for i := range want { + if cfg.EnabledTypes[i] != want[i] { + t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes) + } + } +} + +func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatalf("enable foreign keys: %v", err) + } + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +type paymentConfigSettingRepoStub struct { + values map[string]string +} + +func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, nil +} +func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} +func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go new file mode 100644 index 00000000..894a8198 --- /dev/null +++ b/backend/internal/service/payment_resume_service.go @@ -0,0 +1,248 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + PaymentSourceHostedRedirect = "hosted_redirect" + PaymentSourceWechatInAppResume = "wechat_in_app_resume" + + paymentResumeFallbackSigningKey = "sub2api-payment-resume" + + SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" + SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" + SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" + SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled" + + VisibleMethodSourceOfficialAlipay = "official_alipay" + VisibleMethodSourceEasyPayAlipay = "easypay_alipay" + VisibleMethodSourceOfficialWechat = "official_wxpay" + VisibleMethodSourceEasyPayWechat = "easypay_wxpay" +) + +type ResumeTokenClaims struct { + OrderID int64 `json:"oid"` + UserID int64 `json:"uid,omitempty"` + ProviderInstanceID string `json:"pi,omitempty"` + ProviderKey string `json:"pk,omitempty"` + PaymentType string `json:"pt,omitempty"` + CanonicalReturnURL string `json:"ru,omitempty"` + IssuedAt int64 `json:"iat"` +} + +type PaymentResumeService struct { + signingKey []byte +} + +type visibleMethodLoadBalancer struct { + inner payment.LoadBalancer + configService *PaymentConfigService +} + +func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { + return &PaymentResumeService{signingKey: signingKey} +} + +func NormalizeVisibleMethod(method string) string { + return payment.GetBasePaymentType(strings.TrimSpace(method)) +} + +func NormalizeVisibleMethods(methods []string) []string { + if len(methods) == 0 { + return nil + } + seen := make(map[string]struct{}, len(methods)) + out := make([]string, 0, len(methods)) + for _, method := range methods { + normalized := NormalizeVisibleMethod(method) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out +} + +func NormalizePaymentSource(source string) string { + switch strings.TrimSpace(strings.ToLower(source)) { + case "", PaymentSourceHostedRedirect: + return PaymentSourceHostedRedirect + case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume: + return PaymentSourceWechatInAppResume + default: + return strings.TrimSpace(strings.ToLower(source)) + } +} + +func NormalizeVisibleMethodSource(method, source string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official": + return VisibleMethodSourceOfficialAlipay + case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayAlipay + } + case payment.TypeWxpay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official": + return VisibleMethodSourceOfficialWechat + case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayWechat + } + } + return "" +} + +func VisibleMethodProviderKeyForSource(method, source string) (string, bool) { + switch NormalizeVisibleMethodSource(method, source) { + case VisibleMethodSourceOfficialAlipay: + return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceEasyPayAlipay: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceOfficialWechat: + return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay + case VisibleMethodSourceEasyPayWechat: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay + default: + return "", false + } +} + +func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer { + if inner == nil || configService == nil || configService.settingRepo == nil { + return inner + } + return &visibleMethodLoadBalancer{inner: inner, configService: configService} +} + +func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { + return lb.inner.GetInstanceConfig(ctx, instanceID) +} + +func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) { + visibleMethod := NormalizeVisibleMethod(paymentType) + if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) { + return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount) + } + + enabledKey := visibleMethodEnabledSettingKey(visibleMethod) + sourceKey := visibleMethodSourceSettingKey(visibleMethod) + vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey}) + if err != nil { + return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err) + } + if vals[enabledKey] != "true" { + return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod) + } + + targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey]) + if !ok { + return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod) + } + return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount) +} + +func visibleMethodEnabledSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipayEnabled + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpayEnabled + default: + return "" + } +} + +func visibleMethodSourceSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipaySource + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpaySource + default: + return "" + } +} + +func CanonicalizeReturnURL(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + parsed, err := url.Parse(raw) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https") + } + parsed.Fragment = "" + if parsed.Path == "" { + parsed.Path = "/" + } + return parsed.String(), nil +} + +func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if claims.OrderID <= 0 { + return "", fmt.Errorf("resume token requires order id") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal resume claims: %w", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + return encodedPayload + "." + s.sign(encodedPayload), nil +} + +func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + parts := strings.Split(token, ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") + } + if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed") + } + var claims ResumeTokenClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") + } + if claims.OrderID <= 0 { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id") + } + return &claims, nil +} + +func (s *PaymentResumeService) sign(payload string) string { + key := s.signingKey + if len(key) == 0 { + key = []byte(paymentResumeFallbackSigningKey) + } + mac := hmac.New(sha256.New, key) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go new file mode 100644 index 00000000..e56b4a88 --- /dev/null +++ b/backend/internal/service/payment_resume_service_test.go @@ -0,0 +1,240 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeVisibleMethods(t *testing.T) { + t.Parallel() + + got := NormalizeVisibleMethods([]string{ + "alipay_direct", + "alipay", + " wxpay_direct ", + "wxpay", + "stripe", + }) + + want := []string{"alipay", "wxpay", "stripe"} + if len(got) != len(want) { + t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestNormalizePaymentSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expect string + }{ + {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect}, + {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume}, + {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizePaymentSource(tt.input); got != tt.expect { + t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestCanonicalizeReturnURL(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a") + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://example.com/pay/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2") + } +} + +func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("/payment/result"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject relative URLs") + } +} + +func TestPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + ProviderInstanceID: "19", + ProviderKey: "easypay", + PaymentType: "wxpay", + CanonicalReturnURL: "https://example.com/payment/result", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + claims, err := svc.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken returned error: %v", err) + } + if claims.OrderID != 42 || claims.UserID != 7 { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" { + t.Fatalf("claims provider snapshot mismatch: %+v", claims) + } + if claims.CanonicalReturnURL != "https://example.com/payment/result" { + t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL) + } +} + +func TestNormalizeVisibleMethodSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + input string + want string + }{ + {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay}, + {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay}, + {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat}, + {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat}, + {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want { + t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want) + } + }) + } +} + +func TestVisibleMethodProviderKeyForSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + source string + want string + ok bool + }{ + {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true}, + {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true}, + {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true}, + {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true}, + {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source) + if got != tt.want || ok != tt.ok { + t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != payment.TypeAlipay { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodWxpayEnabled: "false", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil { + t.Fatal("SelectInstance should reject disabled visible method") + } +} + +type paymentSettingRepoStub struct { + values map[string]string +} + +func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil } +func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil } +func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil } + +type captureLoadBalancer struct { + lastProviderKey string + lastPaymentType string +} + +func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) { + c.lastProviderKey = providerKey + c.lastPaymentType = paymentType + return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil +} diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 6fc23f97..e897741a 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -65,15 +65,17 @@ func generateRandomString(n int) string { } type CreateOrderRequest struct { - UserID int64 - Amount float64 - PaymentType string - ClientIP string - IsMobile bool - SrcHost string - SrcURL string - OrderType string - PlanID int64 + UserID int64 + Amount float64 + PaymentType string + ClientIP string + IsMobile bool + SrcHost string + SrcURL string + ReturnURL string + PaymentSource string + OrderType string + PlanID int64 } type CreateOrderResponse struct { @@ -88,6 +90,7 @@ type CreateOrderResponse struct { ClientSecret string `json:"client_secret,omitempty"` ExpiresAt time.Time `json:"expires_at"` PaymentMode string `json:"payment_mode,omitempty"` + ResumeToken string `json:"resume_token,omitempty"` } type OrderListParams struct { @@ -165,10 +168,13 @@ type PaymentService struct { configService *PaymentConfigService userRepo UserRepository groupRepo GroupRepository + resumeService *PaymentResumeService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + return svc } // --- Provider Registry --- @@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string { return &s } +func (s *PaymentService) paymentResume() *PaymentResumeService { + if s.resumeService != nil { + return s.resumeService + } + return NewPaymentResumeService(psResumeSigningKey(s.configService)) +} + +func psResumeSigningKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7f4a2eb1..de555478 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/url" + "os" "sort" "strconv" "strings" @@ -114,6 +115,66 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder } +type ProviderDefaultGrantSettings struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + GrantOnSignup bool + GrantOnFirstBind bool +} + +type AuthSourceDefaultSettings struct { + Email ProviderDefaultGrantSettings + LinuxDo ProviderDefaultGrantSettings + OIDC ProviderDefaultGrantSettings + WeChat ProviderDefaultGrantSettings + ForceEmailOnThirdPartySignup bool +} + +type authSourceDefaultKeySet struct { + balance string + concurrency string + subscriptions string + grantOnSignup string + grantOnFirstBind string +} + +var ( + emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultEmailBalance, + concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, + subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + } + linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultLinuxDoBalance, + concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, + subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + } + oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultOIDCBalance, + concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, + subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + } + weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultWeChatBalance, + concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, + subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + } +) + +const ( + defaultAuthSourceBalance = 0 + defaultAuthSourceConcurrency = 5 +) + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + weChatEnabled := isWeChatOAuthConfigured() // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + WeChatOAuthEnabled: weChatEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, @@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` @@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, @@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +func isWeChatOAuthConfigured() bool { + openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != "" + mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != "" + return openConfigured || mpConfigured +} + // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". func safeRawJSONArray(raw string) json.RawMessage { raw = strings.TrimSpace(raw) @@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS return parseDefaultSubscriptions(value) } +func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) { + keys := []string{ + SettingKeyAuthSourceDefaultEmailBalance, + SettingKeyAuthSourceDefaultEmailConcurrency, + SettingKeyAuthSourceDefaultEmailSubscriptions, + SettingKeyAuthSourceDefaultEmailGrantOnSignup, + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + SettingKeyAuthSourceDefaultLinuxDoBalance, + SettingKeyAuthSourceDefaultLinuxDoConcurrency, + SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + SettingKeyAuthSourceDefaultOIDCBalance, + SettingKeyAuthSourceDefaultOIDCConcurrency, + SettingKeyAuthSourceDefaultOIDCSubscriptions, + SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + SettingKeyAuthSourceDefaultWeChatBalance, + SettingKeyAuthSourceDefaultWeChatConcurrency, + SettingKeyAuthSourceDefaultWeChatSubscriptions, + SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyForceEmailOnThirdPartySignup, + } + + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get auth source default settings: %w", err) + } + + return &AuthSourceDefaultSettings{ + Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys), + LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), + OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), + WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", + }, nil +} + +func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { + if settings == nil { + return nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return err + } + } + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return fmt.Errorf("update auth source default settings: %w", err) + } + return nil +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken } + result.OIDCConnectUsePKCE = true + result.OIDCConnectValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) } else { @@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: true, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { @@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } + effective.UsePKCE = true if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") @@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } @@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" } + effective.UsePKCE = true + effective.ValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) } @@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go new file mode 100644 index 00000000..097bf604 --- /dev/null +++ b/backend/internal/service/setting_service_auth_source_defaults_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type authSourceDefaultsRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for key, value := range settings { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) { + repo := &authSourceDefaultsRepoStub{ + values: map[string]string{ + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true", + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetAuthSourceDefaultSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 12.5, got.Email.Balance) + require.Equal(t, 7, got.Email.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions) + require.False(t, got.Email.GrantOnSignup) + require.False(t, got.Email.GrantOnFirstBind) + require.Equal(t, 0.0, got.LinuxDo.Balance) + require.Equal(t, 5, got.LinuxDo.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions) + require.True(t, got.LinuxDo.GrantOnSignup) + require.True(t, got.LinuxDo.GrantOnFirstBind) + require.Equal(t, 5, got.OIDC.Concurrency) + require.Equal(t, 5, got.WeChat.Concurrency) + require.True(t, got.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) { + repo := &authSourceDefaultsRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + Balance: 1.25, + Concurrency: 3, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + LinuxDo: ProviderDefaultGrantSettings{ + Balance: 2, + Concurrency: 4, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}}, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + OIDC: ProviderDefaultGrantSettings{ + Balance: 3, + Concurrency: 5, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}}, + GrantOnSignup: true, + GrantOnFirstBind: true, + }, + WeChat: ProviderDefaultGrantSettings{ + Balance: 4, + Concurrency: 6, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, + GrantOnSignup: false, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: true, + }) + require.NoError(t, err) + require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup]) + require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind]) + require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup]) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got)) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ab2eb274..e991ebef 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -152,6 +152,7 @@ type PublicSettings struct { CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool + WeChatOAuthEnabled bool BackendModeEnabled bool PaymentEnabled bool OIDCOAuthEnabled bool diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 59f8aa6b..d8b5325c 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -7,19 +7,27 @@ import ( ) type User struct { - ID int64 - Email string - Username string - Notes string - PasswordHash string - Role string - Balance float64 - Concurrency int - Status string - AllowedGroups []int64 - TokenVersion int64 // Incremented on password change to invalidate existing tokens - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Email string + Username string + Notes string + AvatarURL string + AvatarSource string + AvatarMIME string + AvatarByteSize int + AvatarSHA256 string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + TokenVersion int64 // Incremented on password change to invalidate existing tokens + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3490e804..2f6d9427 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/sha256" "crypto/subtle" + "encoding/base64" + "encoding/hex" "fmt" "log/slog" + "net/url" "strings" "time" @@ -17,10 +21,14 @@ var ( ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") + ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") + ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") ) const ( - maxNotifyEmails = 3 // Maximum number of notification emails per user + maxNotifyEmails = 3 // Maximum number of notification emails per user + maxInlineAvatarBytes = 100 * 1024 // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 @@ -47,6 +55,9 @@ type UserRepository interface { GetFirstAdmin(ctx context.Context) (*User, error) Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error + GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) + UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + DeleteUserAvatar(ctx context.Context, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) @@ -71,11 +82,30 @@ type UserRepository interface { type UpdateProfileRequest struct { Email *string `json:"email"` Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` Concurrency *int `json:"concurrency"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type UserAvatar struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type UpsertUserAvatarInput struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } @@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.Username = *req.Username } + if req.AvatarURL != nil { + avatarValue := strings.TrimSpace(*req.AvatarURL) + switch { + case avatarValue == "": + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + applyUserAvatar(user, nil) + default: + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + applyUserAvatar(user, avatar) + } + } + if req.Concurrency != nil { user.Concurrency = *req.Concurrency } @@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return user, nil } +func applyUserAvatar(user *User, avatar *UserAvatar) { + if user == nil { + return + } + if avatar == nil { + user.AvatarURL = "" + user.AvatarSource = "" + user.AvatarMIME = "" + user.AvatarByteSize = 0 + user.AvatarSHA256 = "" + return + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 +} + +func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.HasPrefix(raw, "data:") { + return normalizeInlineUserAvatarInput(raw) + } + + parsed, err := url.Parse(raw) + if err != nil || parsed == nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.TrimSpace(parsed.Host) == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + return UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: raw, + }, nil +} + +func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + body := strings.TrimPrefix(raw, "data:") + meta, encoded, ok := strings.Cut(body, ",") + if !ok { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + meta = strings.TrimSpace(meta) + encoded = strings.TrimSpace(encoded) + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")]) + if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return UpsertUserAvatarInput{}, ErrAvatarNotImage + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if len(decoded) > maxInlineAvatarBytes { + return UpsertUserAvatarInput{}, ErrAvatarTooLarge + } + + sum := sha256.Sum256(decoded) + return UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: raw, + ContentType: contentType, + ByteSize: len(decoded), + SHA256: hex.EncodeToString(sum[:]), + }, nil +} + // ChangePassword 修改密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { @@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error { + if s == nil || s.userRepo == nil || user == nil || user.ID == 0 { + return nil + } + + avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + applyUserAvatar(user, avatar) + return nil +} + // List 获取用户列表(管理员功能) func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index a998d5f4..7d63bb36 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -4,6 +4,9 @@ package service import ( "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "errors" "sync" "sync/atomic" @@ -19,14 +22,65 @@ import ( type mockUserRepo struct { updateBalanceErr error updateBalanceFn func(ctx context.Context, id int64, amount float64) error + getByIDUser *User + getByIDErr error + updateFn func(ctx context.Context, user *User) error + updateCalls int + upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarFn func(ctx context.Context, userID int64) error + deleteAvatarIDs []int64 + getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error) } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { + if m.getByIDErr != nil { + return nil, m.getByIDErr + } + if m.getByIDUser != nil { + cloned := *m.getByIDUser + return &cloned, nil + } + return &User{}, nil +} func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Update(ctx context.Context, user *User) error { + m.updateCalls++ + if m.updateFn != nil { + return m.updateFn(ctx, user) + } + return nil +} +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + if m.getAvatarFn != nil { + return m.getAvatarFn(ctx, userID) + } + return nil, nil +} +func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + m.upsertAvatarArgs = append(m.upsertAvatarArgs, input) + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID) + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil +} func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, cache, svc.billingCache) } + +func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) { + raw := []byte("small-avatar") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + expectedSum := sha256.Sum256(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "avatar@example.com", + Username: "avatar-user", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType) + require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256) + require.Equal(t, dataURL, updated.AvatarURL) + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/png", updated.AvatarMIME) + require.Equal(t, len(raw), updated.AvatarByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256) +} + +func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) { + raw := make([]byte, maxInlineAvatarBytes+1) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 8, + Email: "large-avatar@example.com", + Username: "too-large", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.ErrorIs(t, err, ErrAvatarTooLarge) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Zero(t, repo.updateCalls) +} + +func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) { + remoteURL := "https://cdn.example.com/avatar.png" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "remote-avatar@example.com", + Username: "remote-avatar", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL) + require.Equal(t, remoteURL, updated.AvatarURL) + require.Equal(t, "remote_url", updated.AvatarSource) + require.Zero(t, updated.AvatarByteSize) +} + +func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) { + empty := "" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "delete-avatar@example.com", + Username: "delete-avatar", + AvatarURL: "https://cdn.example.com/old.png", + AvatarSource: "remote_url", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{ + AvatarURL: &empty, + }) + require.NoError(t, err) + require.Equal(t, []int64{10}, repo.deleteAvatarIDs) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, updated.AvatarURL) + require.Empty(t, updated.AvatarSource) +} + +func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "profile-avatar@example.com", + Username: "profile-avatar", + }, + getAvatarFn: func(context.Context, int64) (*UserAvatar, error) { + return &UserAvatar{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/profile.png", + }, nil + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.GetProfile(context.Background(), 12) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL) + require.Equal(t, "remote_url", user.AvatarSource) +} diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql new file mode 100644 index 00000000..117e3ca3 --- /dev/null +++ b/backend/migrations/108_auth_identity_foundation_core.sql @@ -0,0 +1,141 @@ +ALTER TABLE users +ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email', +ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL; + +UPDATE users +SET signup_source = 'email' +WHERE signup_source IS NULL OR signup_source = ''; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'users_signup_source_check' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc')); + END IF; +END $$; + +CREATE TABLE IF NOT EXISTS auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + verified_at TIMESTAMPTZ NULL, + issuer TEXT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key + ON auth_identities (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx + ON auth_identities (user_id); + +CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx + ON auth_identities (user_id, provider_type); + +CREATE TABLE IF NOT EXISTS auth_identity_channels ( + id BIGSERIAL PRIMARY KEY, + identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + channel VARCHAR(20) NOT NULL, + channel_app_id TEXT NOT NULL, + channel_subject TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key + ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject); + +CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx + ON auth_identity_channels (identity_id); + +CREATE TABLE IF NOT EXISTS pending_auth_sessions ( + id BIGSERIAL PRIMARY KEY, + session_token VARCHAR(255) NOT NULL, + intent VARCHAR(40) NOT NULL, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + redirect_to TEXT NOT NULL DEFAULT '', + resolved_email TEXT NOT NULL DEFAULT '', + registration_password_hash TEXT NOT NULL DEFAULT '', + upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb, + local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb, + browser_session_key TEXT NOT NULL DEFAULT '', + completion_code_hash TEXT NOT NULL DEFAULT '', + completion_code_expires_at TIMESTAMPTZ NULL, + email_verified_at TIMESTAMPTZ NULL, + password_verified_at TIMESTAMPTZ NULL, + totp_verified_at TIMESTAMPTZ NULL, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT pending_auth_sessions_intent_check + CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')), + CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key + ON pending_auth_sessions (session_token); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx + ON pending_auth_sessions (target_user_id); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx + ON pending_auth_sessions (expires_at); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx + ON pending_auth_sessions (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx + ON pending_auth_sessions (completion_code_hash); + +CREATE TABLE IF NOT EXISTS identity_adoption_decisions ( + id BIGSERIAL PRIMARY KEY, + pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE, + identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL, + adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE, + adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE, + decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key + ON identity_adoption_decisions (pending_auth_session_id); + +CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx + ON identity_adoption_decisions (identity_id); + +CREATE TABLE IF NOT EXISTS auth_identity_migration_reports ( + id BIGSERIAL PRIMARY KEY, + report_type VARCHAR(40) NOT NULL, + report_key TEXT NOT NULL, + details JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx + ON auth_identity_migration_reports (report_type); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key + ON auth_identity_migration_reports (report_type, report_key); diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql new file mode 100644 index 00000000..ddbbedbc --- /dev/null +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -0,0 +1,125 @@ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'email', + 'email', + LOWER(BTRIM(u.email)), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'users.email', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(u.email, '')) <> '' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'linuxdo', + 'linuxdo', + SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'wechat', + 'wechat', + SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +UPDATE users +SET signup_source = 'linuxdo' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'; + +UPDATE users +SET signup_source = 'wechat' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$'; + +UPDATE users +SET signup_source = 'oidc' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$'; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'oidc_synthetic_email_requires_manual_recovery', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities ai + WHERE ai.user_id = u.id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + ) +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql new file mode 100644 index 00000000..fbaed62e --- /dev/null +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -0,0 +1,60 @@ +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind', + granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT user_provider_default_grants_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')), + CONSTRAINT user_provider_default_grants_reason_check + CHECK (grant_reason IN ('signup', 'first_bind')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key + ON user_provider_default_grants (user_id, provider_type, grant_reason); + +CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx + ON user_provider_default_grants (user_id); + +CREATE TABLE IF NOT EXISTS user_avatars ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_provider VARCHAR(20) NOT NULL DEFAULT 'database', + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL DEFAULT '', + content_type VARCHAR(100) NOT NULL DEFAULT '', + byte_size INT NOT NULL DEFAULT 0, + sha256 VARCHAR(64) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key + ON user_avatars (user_id); + +INSERT INTO settings (key, value) +VALUES + ('auth_source_default_email_balance', '0'), + ('auth_source_default_email_concurrency', '5'), + ('auth_source_default_email_subscriptions', '[]'), + ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_first_bind', 'false'), + ('auth_source_default_linuxdo_balance', '0'), + ('auth_source_default_linuxdo_concurrency', '5'), + ('auth_source_default_linuxdo_subscriptions', '[]'), + ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), + ('auth_source_default_oidc_balance', '0'), + ('auth_source_default_oidc_concurrency', '5'), + ('auth_source_default_oidc_subscriptions', '[]'), + ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_first_bind', 'false'), + ('auth_source_default_wechat_balance', '0'), + ('auth_source_default_wechat_concurrency', '5'), + ('auth_source_default_wechat_subscriptions', '[]'), + ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_first_bind', 'false'), + ('force_email_on_third_party_signup', 'false') +ON CONFLICT (key) DO NOTHING; + diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql new file mode 100644 index 00000000..f222a8d4 --- /dev/null +++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql @@ -0,0 +1,8 @@ +INSERT INTO settings (key, value) +VALUES + ('payment_visible_method_alipay_source', ''), + ('payment_visible_method_wxpay_source', ''), + ('payment_visible_method_alipay_enabled', 'false'), + ('payment_visible_method_wxpay_enabled', 'false'), + ('openai_advanced_scheduler_enabled', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts new file mode 100644 index 00000000..574e1e36 --- /dev/null +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -0,0 +1,60 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const post = vi.fn() + +vi.mock('@/api/client', () => ({ + apiClient: { + post + } +})) + +describe('oauth adoption auth api', () => { + beforeEach(() => { + post.mockReset() + post.mockResolvedValue({ data: {} }) + }) + + it('posts adoption decisions when exchanging pending oauth completion', async () => { + const { exchangePendingOAuthCompletion } = await import('@/api/auth') + + await exchangePendingOAuthCompletion({ + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts linuxdo invitation completion with adoption decisions', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts oidc invitation completion with adoption decisions', async () => { + const { completeOIDCOAuthRegistration } = await import('@/api/auth') + + await completeOIDCOAuthRegistration('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) +}) diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts new file mode 100644 index 00000000..8756146e --- /dev/null +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -0,0 +1,118 @@ +import { describe, expect, it } from 'vitest' + +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + type UpdateSettingsRequest, +} from '@/api/admin/settings' + +describe('admin settings auth source defaults helpers', () => { + it('builds auth source defaults state from flat settings fields', () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_balance: 9.5, + auth_source_default_email_concurrency: 3, + auth_source_default_email_subscriptions: [ + { group_id: 1, validity_days: 30 }, + ], + auth_source_default_email_grant_on_signup: false, + auth_source_default_email_grant_on_first_bind: true, + auth_source_default_linuxdo_balance: 6, + auth_source_default_linuxdo_concurrency: 8, + auth_source_default_linuxdo_subscriptions: [ + { group_id: 2, validity_days: 60 }, + ], + auth_source_default_linuxdo_grant_on_signup: true, + auth_source_default_linuxdo_grant_on_first_bind: false, + }) + + expect(state.email).toEqual({ + balance: 9.5, + concurrency: 3, + subscriptions: [{ group_id: 1, validity_days: 30 }], + grant_on_signup: false, + grant_on_first_bind: true, + }) + expect(state.linuxdo).toEqual({ + balance: 6, + concurrency: 8, + subscriptions: [{ group_id: 2, validity_days: 60 }], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.oidc).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.wechat).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + }) + + it('appends auth source defaults back onto update payload', () => { + const payload: UpdateSettingsRequest = { + site_name: 'Sub2API', + } + + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 1.25, + concurrency: 2, + subscriptions: [{ group_id: 3, validity_days: 7 }], + grant_on_signup: true, + grant_on_first_bind: false, + }, + linuxdo: { + balance: 0, + concurrency: 6, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: true, + }, + oidc: { + balance: 4, + concurrency: 9, + subscriptions: [{ group_id: 9, validity_days: 90 }], + grant_on_signup: true, + grant_on_first_bind: true, + }, + wechat: { + balance: 2, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }, + }) + + expect(payload).toMatchObject({ + site_name: 'Sub2API', + auth_source_default_email_balance: 1.25, + auth_source_default_email_concurrency: 2, + auth_source_default_email_subscriptions: [{ group_id: 3, validity_days: 7 }], + auth_source_default_email_grant_on_signup: true, + auth_source_default_email_grant_on_first_bind: false, + auth_source_default_linuxdo_balance: 0, + auth_source_default_linuxdo_concurrency: 6, + auth_source_default_linuxdo_subscriptions: [], + auth_source_default_linuxdo_grant_on_signup: false, + auth_source_default_linuxdo_grant_on_first_bind: true, + auth_source_default_oidc_balance: 4, + auth_source_default_oidc_concurrency: 9, + auth_source_default_oidc_subscriptions: [{ group_id: 9, validity_days: 90 }], + auth_source_default_oidc_grant_on_signup: true, + auth_source_default_oidc_grant_on_first_bind: true, + auth_source_default_wechat_balance: 2, + auth_source_default_wechat_concurrency: 5, + auth_source_default_wechat_subscriptions: [], + auth_source_default_wechat_grant_on_signup: false, + auth_source_default_wechat_grant_on_first_bind: false, + }) + }) +}) diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 1e4a3053..8e182c1c 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -11,6 +11,81 @@ export interface DefaultSubscriptionSetting { validity_days: number } +export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface AuthSourceDefaultsValue { + balance: number + concurrency: number + subscriptions: DefaultSubscriptionSetting[] + grant_on_signup: boolean + grant_on_first_bind: boolean +} + +export type AuthSourceDefaultsState = Record + +const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat'] +const AUTH_SOURCE_DEFAULT_BALANCE = 0 +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5 + +export function normalizeDefaultSubscriptionSettings( + subscriptions: DefaultSubscriptionSetting[] | null | undefined +): DefaultSubscriptionSetting[] { + if (!Array.isArray(subscriptions)) return [] + + return subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: Math.floor(item.group_id), + validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) + })) +} + +export function buildAuthSourceDefaultsState( + settings: Partial +): AuthSourceDefaultsState { + const raw = settings as Record + + return AUTH_SOURCE_TYPES.reduce((acc, source) => { + const subscriptions = raw[`auth_source_default_${source}_subscriptions`] + acc[source] = { + balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE), + concurrency: Math.max( + 1, + Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY) + ), + subscriptions: normalizeDefaultSubscriptionSettings( + Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : [] + ), + grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false, + grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + } + return acc + }, {} as AuthSourceDefaultsState) +} + +export function appendAuthSourceDefaultsToUpdateRequest( + payload: UpdateSettingsRequest, + authSourceDefaults: AuthSourceDefaultsState +): UpdateSettingsRequest { + const target = payload as Record + + for (const source of AUTH_SOURCE_TYPES) { + const current = authSourceDefaults[source] + target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0 + target[`auth_source_default_${source}_concurrency`] = Math.max( + 1, + Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY) + ) + target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings( + current.subscriptions + ) + target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup + target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind + } + + return payload +} + /** * System settings interface */ @@ -29,6 +104,27 @@ export interface SystemSettings { default_balance: number default_concurrency: number default_subscriptions: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean // OEM settings site_name: string site_logo: string @@ -137,6 +233,11 @@ export interface SystemSettings { payment_cancel_rate_limit_window: number payment_cancel_rate_limit_unit: string payment_cancel_rate_limit_window_mode: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled: boolean @@ -158,6 +259,27 @@ export interface UpdateSettingsRequest { default_balance?: number default_concurrency?: number default_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean site_name?: string site_logo?: string site_subtitle?: string @@ -245,6 +367,11 @@ export interface UpdateSettingsRequest { payment_cancel_rate_limit_window?: number payment_cancel_rate_limit_unit?: string payment_cancel_rate_limit_window_mode?: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled?: boolean balance_low_notify_threshold?: number diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index d7abcd6a..10b6ca58 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -198,6 +198,26 @@ export interface PendingOAuthExchangeResponse { suggested_avatar_url?: string } +export interface OAuthAdoptionDecision { + adoptDisplayName?: boolean + adoptAvatar?: boolean +} + +function serializeOAuthAdoptionDecision( + decision?: OAuthAdoptionDecision +): Record { + const payload: Record = {} + + if (typeof decision?.adoptDisplayName === 'boolean') { + payload.adopt_display_name = decision.adoptDisplayName + } + if (typeof decision?.adoptAvatar === 'boolean') { + payload.adopt_avatar = decision.adoptAvatar + } + + return payload +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -353,7 +373,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ access_token: string @@ -361,7 +382,8 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } @@ -372,7 +394,8 @@ export async function completeLinuxDoOAuthRegistration( * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - invitationCode: string + invitationCode: string, + decision?: OAuthAdoptionDecision ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ access_token: string @@ -380,13 +403,19 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } -export async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) return data } diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue new file mode 100644 index 00000000..94e20222 --- /dev/null +++ b/frontend/src/components/auth/WechatOAuthSection.vue @@ -0,0 +1,53 @@ + + + diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts new file mode 100644 index 00000000..810832a0 --- /dev/null +++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts @@ -0,0 +1,74 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/login' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.signIn') { + return `Continue with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oauthOrContinue') { + return 'or continue' + } + return key + }, + }), +})) + +describe('WechatOAuthSection', () => { + beforeEach(() => { + routeState.query = { redirect: '/billing?plan=pro' } + locationState.current = { href: 'http://localhost/login' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('starts the open WeChat OAuth flow with the current redirect target', async () => { + const wrapper = mount(WechatOAuthSection) + + expect(wrapper.text()).toContain('WeChat') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('uses mp mode inside the WeChat browser', async () => { + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0 MicroMessenger', + }) + const wrapper = mount(WechatOAuthSection) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) +}) diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 974dee66..8f5a5666 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -141,7 +141,9 @@ const props = defineProps<{ orderType?: string }>() -const emit = defineEmits<{ done: []; success: [] }>() +type PaymentOutcome = 'success' | 'cancelled' | 'expired' + +const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>() const { t } = useI18n() const paymentStore = usePaymentStore() @@ -154,7 +156,7 @@ const cancelling = ref(false) const paidOrder = ref(null) // Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired' -const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null) +const outcome = ref(null) let pollTimer: ReturnType | null = null let countdownTimer: ReturnType | null = null @@ -194,10 +196,19 @@ const countdownDisplay = computed(() => { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + if (!win || win.closed) { + window.location.href = props.payUrl + } } } +function setOutcome(next: PaymentOutcome) { + if (outcome.value === next) return + outcome.value = next + emit('settled', next) +} + async function renderQR() { await nextTick() if (!qrCanvas.value || !qrUrl.value) return @@ -214,23 +225,23 @@ async function pollStatus() { if (order.status === 'COMPLETED' || order.status === 'PAID') { cleanup() paidOrder.value = order - outcome.value = 'success' + setOutcome('success') emit('success') } else if (order.status === 'CANCELLED') { cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } else if (order.status === 'EXPIRED' || order.status === 'FAILED') { cleanup() - outcome.value = 'expired' + setOutcome('expired') } } function startCountdown(seconds: number) { remainingSeconds.value = Math.max(0, seconds) - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return } + if (remainingSeconds.value <= 0) { setOutcome('expired'); return } countdownTimer = setInterval(() => { remainingSeconds.value-- - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() } + if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() } }, 1000) } @@ -240,7 +251,7 @@ async function handleCancel() { try { await paymentAPI.cancelOrder(props.orderId) cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } finally { diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts new file mode 100644 index 00000000..f5212f15 --- /dev/null +++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts @@ -0,0 +1,163 @@ +import { describe, expect, it } from 'vitest' +import type { CreateOrderResult, MethodLimit } from '@/types/payment' +import { + decidePaymentLaunch, + getVisibleMethods, + readPaymentRecoverySnapshot, + type PaymentRecoverySnapshot, +} from '@/components/payment/paymentFlow' + +function methodLimit(overrides: Partial = {}): MethodLimit { + return { + daily_limit: 0, + daily_used: 0, + daily_remaining: 0, + single_min: 0, + single_max: 0, + fee_rate: 0, + available: true, + ...overrides, + } +} + +function createOrderResult(overrides: Partial = {}): CreateOrderResult { + return { + order_id: 101, + amount: 88, + pay_amount: 88, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + ...overrides, + } +} + +describe('getVisibleMethods', () => { + it('filters hidden provider methods and normalizes aliases', () => { + const visible = getVisibleMethods({ + alipay_direct: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + stripe: methodLimit({ fee_rate: 3 }), + }) + + expect(visible).toEqual({ + alipay: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + }) + }) + + it('prefers canonical visible methods over aliases when both exist', () => { + const visible = getVisibleMethods({ + alipay: methodLimit({ single_min: 2 }), + alipay_direct: methodLimit({ single_min: 9 }), + wxpay_direct: methodLimit({ fee_rate: 1.2 }), + }) + + expect(visible.alipay.single_min).toBe(2) + expect(visible.wxpay.fee_rate).toBe(1.2) + }) +}) + +describe('decidePaymentLaunch', () => { + it('uses Stripe popup waiting flow for desktop Alipay client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + resume_token: 'resume-1', + }), { + visibleMethod: 'alipay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('stripe_popup') + expect(decision.paymentState.paymentType).toBe('alipay') + expect(decision.stripeMethod).toBe('alipay') + expect(decision.recovery.resumeToken).toBe('resume-1') + }) + + it('uses Stripe route flow for mobile WeChat client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + }), { + visibleMethod: 'wxpay', + orderType: 'subscription', + isMobile: true, + }) + + expect(decision.kind).toBe('stripe_route') + expect(decision.stripeMethod).toBe('wechat_pay') + expect(decision.paymentState.orderType).toBe('subscription') + }) + + it('keeps hosted redirect metadata for recovery flows', () => { + const decision = decidePaymentLaunch(createOrderResult({ + pay_url: 'https://pay.example.com/session/abc', + payment_mode: 'popup', + resume_token: 'resume-2', + }), { + visibleMethod: 'wxpay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('redirect_waiting') + expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc') + expect(decision.recovery.paymentMode).toBe('popup') + expect(decision.recovery.resumeToken).toBe('resume-2') + }) +}) + +describe('readPaymentRecoverySnapshot', () => { + it('restores an unexpired snapshot when the resume token matches', () => { + const snapshot: PaymentRecoverySnapshot = { + orderId: 33, + amount: 18, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/33', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-33', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + } + + const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'resume-33', + }) + + expect(restored?.orderId).toBe(33) + }) + + it('drops expired or mismatched recovery snapshots', () => { + const expiredSnapshot: PaymentRecoverySnapshot = { + orderId: 55, + amount: 18, + qrCode: '', + expiresAt: '2024-01-01T00:10:00.000Z', + paymentType: 'wxpay', + payUrl: 'https://pay.example.com/session/55', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-55', + createdAt: Date.UTC(2024, 0, 1, 0, 0, 0), + } + + expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), { + now: Date.UTC(2024, 0, 1, 0, 20, 0), + resumeToken: 'resume-55', + })).toBeNull() + + expect(readPaymentRecoverySnapshot(JSON.stringify({ + ...expiredSnapshot, + expiresAt: '2099-01-01T00:10:00.000Z', + }), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'other-token', + })).toBeNull() + }) +}) diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts new file mode 100644 index 00000000..70225a0c --- /dev/null +++ b/frontend/src/components/payment/paymentFlow.ts @@ -0,0 +1,197 @@ +import type { CreateOrderResult, MethodLimit, OrderType } from '@/types/payment' + +export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current' + +const VISIBLE_METHOD_ALIASES = { + alipay: 'alipay', + alipay_direct: 'alipay', + wxpay: 'wxpay', + wxpay_direct: 'wxpay', +} as const + +export type VisiblePaymentMethod = 'alipay' | 'wxpay' +export type StripeVisibleMethod = 'alipay' | 'wechat_pay' +export type PaymentLaunchKind = + | 'qr_waiting' + | 'redirect_waiting' + | 'stripe_popup' + | 'stripe_route' + | 'unhandled' + +export interface PaymentRecoverySnapshot { + orderId: number + amount: number + qrCode: string + expiresAt: string + paymentType: string + payUrl: string + clientSecret: string + payAmount: number + orderType: OrderType | '' + paymentMode: string + resumeToken: string + createdAt: number +} + +export interface PaymentLaunchContext { + visibleMethod: string + orderType: OrderType + isMobile: boolean + now?: number + stripePopupUrl?: string + stripeRouteUrl?: string +} + +export interface PaymentLaunchDecision { + kind: PaymentLaunchKind + paymentState: PaymentRecoverySnapshot + recovery: PaymentRecoverySnapshot + stripeMethod?: StripeVisibleMethod +} + +type CreateOrderFlowResult = CreateOrderResult & { + resume_token?: string +} + +type StorageWriter = Pick + +export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' { + const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES] + return normalized ?? '' +} + +export function getVisibleMethods(methods: Record): Record { + const visible: Record = {} + + Object.entries(methods).forEach(([type, limit]) => { + const normalized = normalizeVisibleMethod(type) + if (!normalized) return + + const isCanonical = type === normalized + const existing = visible[normalized] + if (!existing || isCanonical) { + visible[normalized] = { ...limit } + } + }) + + return visible +} + +export function decidePaymentLaunch( + result: CreateOrderFlowResult, + context: PaymentLaunchContext, +): PaymentLaunchDecision { + const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod + const baseState = createPaymentRecoverySnapshot({ + orderId: result.order_id, + amount: result.amount, + qrCode: result.qr_code || '', + expiresAt: result.expires_at || '', + paymentType: visibleMethod, + payUrl: result.pay_url || '', + clientSecret: result.client_secret || '', + payAmount: result.pay_amount, + orderType: context.orderType, + paymentMode: (result.payment_mode || '').trim(), + resumeToken: result.resume_token || '', + }, context.now) + + if (baseState.clientSecret) { + const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay' + const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile + ? 'stripe_popup' + : 'stripe_route' + const payUrl = kind === 'stripe_popup' + ? context.stripePopupUrl || context.stripeRouteUrl || '' + : context.stripeRouteUrl || context.stripePopupUrl || '' + const paymentState = { ...baseState, payUrl } + return { kind, paymentState, recovery: paymentState, stripeMethod } + } + + if (baseState.qrCode) { + return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState } + } + + if (baseState.payUrl) { + return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState } + } + + return { kind: 'unhandled', paymentState: baseState, recovery: baseState } +} + +export function createPaymentRecoverySnapshot( + state: Omit, + now = Date.now(), +): PaymentRecoverySnapshot { + return { + ...state, + createdAt: now, + } +} + +export function writePaymentRecoverySnapshot( + storage: StorageWriter, + snapshot: PaymentRecoverySnapshot, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.setItem(key, JSON.stringify(snapshot)) +} + +export function clearPaymentRecoverySnapshot( + storage: Pick, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.removeItem(key) +} + +export function readPaymentRecoverySnapshot( + raw: string | null | undefined, + options: { now?: number; resumeToken?: string } = {}, +): PaymentRecoverySnapshot | null { + if (!raw) return null + + try { + const parsed = JSON.parse(raw) as Partial + if ( + typeof parsed.orderId !== 'number' + || typeof parsed.amount !== 'number' + || typeof parsed.qrCode !== 'string' + || typeof parsed.expiresAt !== 'string' + || typeof parsed.paymentType !== 'string' + || typeof parsed.payUrl !== 'string' + || typeof parsed.clientSecret !== 'string' + || typeof parsed.payAmount !== 'number' + || typeof parsed.paymentMode !== 'string' + || typeof parsed.resumeToken !== 'string' + || typeof parsed.createdAt !== 'number' + ) { + return null + } + + const now = options.now ?? Date.now() + const expiresAt = Date.parse(parsed.expiresAt) + if (Number.isFinite(expiresAt) && expiresAt <= now) { + return null + } + if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) { + return null + } + + return { + orderId: parsed.orderId, + amount: parsed.amount, + qrCode: parsed.qrCode, + expiresAt: parsed.expiresAt, + paymentType: parsed.paymentType, + payUrl: parsed.payUrl, + clientSecret: parsed.clientSecret, + payAmount: parsed.payAmount, + orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance', + paymentMode: parsed.paymentMode, + resumeToken: parsed.resumeToken, + createdAt: parsed.createdAt, + } + } catch { + return null + } +} diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts new file mode 100644 index 00000000..84b20452 --- /dev/null +++ b/frontend/src/router/__tests__/wechat-route.spec.ts @@ -0,0 +1,55 @@ +import { describe, expect, it, vi } from 'vitest' + +const authStore = vi.hoisted(() => ({ + checkAuth: vi.fn(), + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, +})) + +const appStore = vi.hoisted(() => ({ + siteName: 'Sub2API', + backendModeEnabled: false, + cachedPublicSettings: null as null | Record, +})) + +vi.mock('@/stores/auth', () => ({ + useAuthStore: () => authStore, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => appStore, +})) + +vi.mock('@/stores/adminSettings', () => ({ + useAdminSettingsStore: () => ({ + customMenuItems: [], + }), +})) + +vi.mock('@/composables/useNavigationLoading', () => ({ + useNavigationLoadingState: () => ({ + startNavigation: vi.fn(), + endNavigation: vi.fn(), + isLoading: { value: false }, + }), +})) + +vi.mock('@/composables/useRoutePrefetch', () => ({ + useRoutePrefetch: () => ({ + triggerPrefetch: vi.fn(), + cancelPendingPrefetch: vi.fn(), + resetPrefetchState: vi.fn(), + }), +})) + +describe('router WeChat OAuth route', () => { + it('registers the WeChat callback route as a public route', async () => { + const { default: router } = await import('@/router') + const route = router.getRoutes().find((record) => record.name === 'WeChatOAuthCallback') + + expect(route?.path).toBe('/auth/wechat/callback') + expect(route?.meta.requiresAuth).toBe(false) + expect(route?.meta.title).toBe('WeChat OAuth Callback') + }) +}) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index ad6e71c4..beaa1da2 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [ title: 'LinuxDo OAuth Callback' } }, + { + path: '/auth/wechat/callback', + name: 'WeChatOAuthCallback', + component: () => import('@/views/auth/WechatCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'WeChat OAuth Callback' + } + }, { path: '/auth/oidc/callback', name: 'OIDCOAuthCallback', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 1995383d..1b1af87b 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -336,6 +336,7 @@ export const useAppStore = defineStore('app', () => { custom_menu_items: [], custom_endpoints: [], linuxdo_oauth_enabled: false, + wechat_oauth_enabled: false, oidc_oauth_enabled: false, oidc_oauth_provider_name: 'OIDC', backend_mode_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 89fd777f..529eff55 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -123,6 +123,7 @@ export interface PublicSettings { custom_menu_items: CustomMenuItem[] custom_endpoints: CustomEndpoint[] linuxdo_oauth_enabled: boolean + wechat_oauth_enabled: boolean oidc_oauth_enabled: boolean oidc_oauth_provider_name: string backend_mode_enabled: boolean diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 8bfa0f2b..0d23baa5 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1586,6 +1586,221 @@ + +
+
+

+ {{ localText('认证来源默认值', 'Auth Source Defaults') }} +

+

+ {{ + localText( + '按注册来源配置新用户默认余额、并发、订阅与授权策略。', + 'Configure per-source default balance, concurrency, subscriptions, and grant rules.' + ) + }} +

+
+
+
+
+ +

+ {{ + localText( + '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。', + 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.' + ) + }} +

+
+ +
+ +
+
+
+
{{ authSource.title }}
+

+ {{ authSource.description }} +

+
+ +
+
+ + +
+
+ + +
+
+ +
+
+
+ +

+ {{ + localText( + '来源首次注册成功后立即发放默认权益。', + 'Grant default entitlements immediately after signup.' + ) + }} +

+
+ +
+ +
+
+ +

+ {{ + localText( + '来源首次绑定到现有账号时发放默认权益。', + 'Grant default entitlements when the source is first bound to an existing user.' + ) + }} +

+
+ +
+
+ +
+
+
+ +

+ {{ + localText( + '仅对当前认证来源生效,未配置时不追加来源专属订阅。', + 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.' + ) + }} +

+
+ +
+ +
+ {{ + localText( + '当前来源未配置专属默认订阅。', + 'No source-specific default subscriptions configured.' + ) + }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+
+
+
+
@@ -1643,19 +1858,38 @@

-
-
-
@@ -2450,6 +2684,59 @@

+
+
+
+
+ +

+ {{ + localText( + '控制前台结算页是否展示该方式,以及展示时使用的来源键。', + 'Controls whether checkout shows this method and which source key it exposes.' + ) + }} +

+
+ +
+ +
+ + +

+ {{ + localText( + '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。', + 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.' + ) + }} +

+
+
+
@@ -2827,7 +3114,14 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api' +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + normalizeDefaultSubscriptionSettings, +} from '@/api/admin/settings' import type { + AuthSourceDefaultsState, + AuthSourceType, SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, @@ -2864,6 +3158,10 @@ const { t, locale } = useI18n() const appStore = useAppStore() const adminSettingsStore = useAdminSettingsStore() +function localText(zh: string, en: string): string { + return locale.value.startsWith('zh') ? zh : en +} + type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup' const activeTab = ref('general') const settingsTabs = [ @@ -2960,6 +3258,12 @@ type SettingsForm = SystemSettings & { turnstile_secret_key: string linuxdo_connect_client_secret: string oidc_connect_client_secret: string + force_email_on_third_party_signup: boolean + payment_visible_method_alipay_source: string + payment_visible_method_wxpay_source: string + payment_visible_method_alipay_enabled: boolean + payment_visible_method_wxpay_enabled: boolean + openai_advanced_scheduler_enabled: boolean } const form = reactive({ @@ -2974,6 +3278,7 @@ const form = reactive({ default_balance: 0, default_concurrency: 1, default_subscriptions: [], + force_email_on_third_party_signup: false, site_name: 'Sub2API', site_logo: '', site_subtitle: 'Subscription to API Conversion Platform', @@ -2983,7 +3288,7 @@ const form = reactive({ home_content: '', backend_mode_enabled: false, hide_ccs_import_button: false, - payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', + payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', payment_visible_method_alipay_source: '', payment_visible_method_wxpay_source: '', payment_visible_method_alipay_enabled: false, payment_visible_method_wxpay_enabled: false, table_default_page_size: tablePageSizeDefault, table_page_size_options: [10, 20, 50, 100], custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, @@ -3051,6 +3356,7 @@ const form = reactive({ max_claude_code_version: '', // 分组隔离 allow_ungrouped_key_scheduling: false, + openai_advanced_scheduler_enabled: false, // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, @@ -3063,6 +3369,74 @@ const form = reactive({ account_quota_notify_emails: [] as NotifyEmailEntry[] }) +const authSourceDefaults = reactive(buildAuthSourceDefaultsState({})) + +const authSourceDefaultsMeta = computed(() => [ + { + source: 'email' as AuthSourceType, + title: localText('邮箱注册', 'Email signup'), + description: localText('适用于邮箱密码注册的新用户默认配额。', 'Default quota grants for email-password signups.') + }, + { + source: 'linuxdo' as AuthSourceType, + title: localText('Linux DO 登录', 'Linux DO signup'), + description: localText('适用于 Linux DO 第三方注册的新用户默认配额。', 'Default quota grants for Linux DO signups.') + }, + { + source: 'oidc' as AuthSourceType, + title: localText('OIDC 登录', 'OIDC signup'), + description: localText('适用于 OIDC 第三方注册的新用户默认配额。', 'Default quota grants for OIDC signups.') + }, + { + source: 'wechat' as AuthSourceType, + title: localText('微信登录', 'WeChat signup'), + description: localText('适用于微信第三方注册的新用户默认配额。', 'Default quota grants for WeChat signups.') + }, +]) + +const paymentVisibleMethodCards = computed(() => [ + { + key: 'alipay' as const, + title: t('payment.methods.alipay'), + enabledField: 'payment_visible_method_alipay_enabled' as const, + sourceField: 'payment_visible_method_alipay_source' as const, + }, + { + key: 'wxpay' as const, + title: t('payment.methods.wxpay'), + enabledField: 'payment_visible_method_wxpay_enabled' as const, + sourceField: 'payment_visible_method_wxpay_source' as const, + }, +]) + +function getPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay'): boolean { + return method === 'alipay' + ? form.payment_visible_method_alipay_enabled + : form.payment_visible_method_wxpay_enabled +} + +function setPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay', enabled: boolean) { + if (method === 'alipay') { + form.payment_visible_method_alipay_enabled = enabled + return + } + form.payment_visible_method_wxpay_enabled = enabled +} + +function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string { + return method === 'alipay' + ? form.payment_visible_method_alipay_source + : form.payment_visible_method_wxpay_source +} + +function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) { + if (method === 'alipay') { + form.payment_visible_method_alipay_source = source + return + } + form.payment_visible_method_wxpay_source = source +} + // Proxies for web search emulation ProxySelector const webSearchProxies = ref([]) @@ -3428,15 +3802,9 @@ async function loadSettings() { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings)) form.backend_mode_enabled = settings.backend_mode_enabled - form.default_subscriptions = Array.isArray(settings.default_subscriptions) - ? settings.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item) => ({ - group_id: item.group_id, - validity_days: item.validity_days - })) - : [] + form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( settings.registration_email_suffix_whitelist ) @@ -3471,10 +3839,18 @@ async function loadSubscriptionGroups() { } } +function findNextAvailableSubscriptionGroup( + existingGroupIDs: number[] +): AdminGroup | undefined { + const existing = new Set(existingGroupIDs) + return subscriptionGroups.value.find((group) => !existing.has(group.id)) +} + function addDefaultSubscription() { if (subscriptionGroups.value.length === 0) return - const existing = new Set(form.default_subscriptions.map((item) => item.group_id)) - const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id)) + const candidate = findNextAvailableSubscriptionGroup( + form.default_subscriptions.map((item) => item.group_id) + ) if (!candidate) return form.default_subscriptions.push({ group_id: candidate.id, @@ -3486,6 +3862,36 @@ function removeDefaultSubscription(index: number) { form.default_subscriptions.splice(index, 1) } +function addAuthSourceDefaultSubscription(source: AuthSourceType) { + if (subscriptionGroups.value.length === 0) return + const candidate = findNextAvailableSubscriptionGroup( + authSourceDefaults[source].subscriptions.map((item) => item.group_id) + ) + if (!candidate) return + authSourceDefaults[source].subscriptions.push({ + group_id: candidate.id, + validity_days: 30 + }) +} + +function removeAuthSourceDefaultSubscription(source: AuthSourceType, index: number) { + authSourceDefaults[source].subscriptions.splice(index, 1) +} + +function findDuplicateDefaultSubscription( + subscriptions: DefaultSubscriptionSetting[] +): DefaultSubscriptionSetting | undefined { + const seenGroupIDs = new Set() + + return subscriptions.find((item) => { + if (seenGroupIDs.has(item.group_id)) { + return true + } + seenGroupIDs.add(item.group_id) + return false + }) +} + async function saveSettings() { saving.value = true try { @@ -3520,21 +3926,12 @@ async function saveSettings() { form.table_default_page_size = normalizedTableDefaultPageSize form.table_page_size_options = normalizedTablePageSizeOptions - const normalizedDefaultSubscriptions = form.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item: DefaultSubscriptionSetting) => ({ - group_id: item.group_id, - validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) - })) - - const seenGroupIDs = new Set() - const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => { - if (seenGroupIDs.has(item.group_id)) { - return true - } - seenGroupIDs.add(item.group_id) - return false - }) + const normalizedDefaultSubscriptions = normalizeDefaultSubscriptionSettings( + form.default_subscriptions + ) + const duplicateDefaultSubscription = findDuplicateDefaultSubscription( + normalizedDefaultSubscriptions + ) if (duplicateDefaultSubscription) { appStore.showError( t('admin.settings.defaults.defaultSubscriptionsDuplicate', { @@ -3544,6 +3941,23 @@ async function saveSettings() { return } + for (const authSource of authSourceDefaultsMeta.value) { + authSourceDefaults[authSource.source].subscriptions = normalizeDefaultSubscriptionSettings( + authSourceDefaults[authSource.source].subscriptions + ) + const duplicate = findDuplicateDefaultSubscription( + authSourceDefaults[authSource.source].subscriptions + ) + if (duplicate) { + appStore.showError( + `${authSource.title}: ${t('admin.settings.defaults.defaultSubscriptionsDuplicate', { + groupId: duplicate.group_id + })}` + ) + return + } + } + // Validate URL fields — novalidate disables browser-native checks, so we validate here const isValidHttpUrl = (url: string): boolean => { if (!url) return true @@ -3571,6 +3985,7 @@ async function saveSettings() { default_balance: form.default_balance, default_concurrency: form.default_concurrency, default_subscriptions: normalizedDefaultSubscriptions, + force_email_on_third_party_signup: form.force_email_on_third_party_signup, site_name: form.site_name, site_logo: form.site_logo, site_subtitle: form.site_subtitle, @@ -3655,6 +4070,11 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, + payment_visible_method_alipay_source: form.payment_visible_method_alipay_source, + payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source, + payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled, + payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled, + openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled, // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, @@ -3663,12 +4083,15 @@ async function saveSettings() { account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } + appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults) + const updated = await adminAPI.settings.updateSettings(payload) for (const [key, value] of Object.entries(updated)) { if (value !== null && value !== undefined) { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated)) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( updated.registration_email_suffix_whitelist ) diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index af48959b..0a125def 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -11,32 +11,94 @@
-
-

- {{ t('auth.linuxdo.invitationRequired') }} -

-
- -
- -

- {{ invitationError }} -

-
- +
+
+

+ Use LinuxDo profile details +

+

+ Choose whether to apply the nickname or avatar from LinuxDo to this account. +

+
+ + + + +
+
+ + + +
@@ -71,7 +133,12 @@ import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' -import { completeLinuxDoOAuthRegistration } from '@/api/auth' +import { + completeLinuxDoOAuthRegistration, + exchangePendingOAuthCompletion, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -85,11 +152,16 @@ const errorMessage = ref('') // Invitation code flow state const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -106,6 +178,54 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.linuxdo.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -113,8 +233,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeLinuxDoOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -134,63 +254,65 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { const params = parseFragmentParams() - - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.linuxdo.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.linuxdo.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - // Store refresh token and expires_at (convert to timestamp) if provided - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) - } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return + } + + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } @@ -209,4 +331,3 @@ onMounted(async () => { transform: translateY(-8px); } - diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 70b64e3f..fa4ac34c 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const backendModeEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') @@ -267,6 +274,7 @@ onMounted(async () => { turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled backendModeEnabled.value = settings.backend_mode_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index a6cb6c12..55f8af6e 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -15,36 +15,99 @@
-
-

- {{ t('auth.oidc.invitationRequired', { providerName }) }} -

-
- -
- -

- {{ invitationError }} -

-
- +
+
+

+ Use {{ providerName }} profile details +

+

+ Choose whether to apply the nickname or avatar from {{ providerName }} to this + account. +

+
+ + + + +
+
+ + + +
@@ -81,7 +144,10 @@ import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, - getPublicSettings + exchangePendingOAuthCompletion, + getPublicSettings, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse } from '@/api/auth' const route = useRoute() @@ -95,12 +161,17 @@ const isProcessing = ref(true) const errorMessage = ref('') const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') const providerName = ref('OIDC') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -129,6 +200,54 @@ async function loadProviderName() { } } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.oidc.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -136,8 +255,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeOIDCOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -157,63 +276,67 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { void loadProviderName() const params = parseFragmentParams() - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.oidc.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.oidc.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) - } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return + } + + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue index bc8b8dce..378f9d8a 100644 --- a/frontend/src/views/auth/RegisterView.vue +++ b/frontend/src/views/auth/RegisterView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileSiteKey = ref('') const siteName = ref('Sub2API') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') const registrationEmailSuffixWhitelist = ref([]) @@ -397,6 +404,7 @@ onMounted(async () => { turnstileSiteKey.value = settings.turnstile_site_key || '' siteName.value = settings.site_name || 'Sub2API' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist( diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue new file mode 100644 index 00000000..407b395b --- /dev/null +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -0,0 +1,361 @@ + + + + + diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts new file mode 100644 index 00000000..60a40474 --- /dev/null +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -0,0 +1,180 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import LinuxDoCallbackView from '../LinuxDoCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeLinuxDoOAuthRegistration = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) +})) + +describe('LinuxDoCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeLinuxDoOAuthRegistration.mockReset() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[0].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts new file mode 100644 index 00000000..299c0746 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -0,0 +1,191 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import OidcCallbackView from '../OidcCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeOIDCOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (!params?.providerName) { + return key + } + return `${key}:${params.providerName}` + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) +})) + +describe('OidcCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeOIDCOAuthRegistration.mockReset() + getPublicSettings.mockReset() + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID' + }) + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[0].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: false, + adoptAvatar: true + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[1].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts new file mode 100644 index 00000000..a9e2ada2 --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -0,0 +1,241 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' + +const { + postMock, + replaceMock, + setTokenMock, + showSuccessMock, + showErrorMock, + routeState, +} = vi.hoisted(() => ({ + postMock: vi.fn(), + replaceMock: vi.fn(), + setTokenMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + routeState: { + query: {} as Record, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.callbackTitle') { + return `Signing you in with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackProcessing') { + return `Completing login with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.invitationRequired') { + return `${params?.providerName ?? ''} invitation required`.trim() + } + if (key === 'auth.oidc.completeRegistration') { + return 'Complete registration' + } + if (key === 'auth.oidc.completing') { + return 'Completing' + } + if (key === 'auth.oidc.backToLogin') { + return 'Back to login' + } + if (key === 'auth.invitationCodePlaceholder') { + return 'Invitation code' + } + if (key === 'auth.loginSuccess') { + return 'Login success' + } + if (key === 'auth.loginFailed') { + return 'Login failed' + } + if (key === 'auth.oidc.callbackHint') { + return 'Callback hint' + } + if (key === 'auth.oidc.callbackMissingToken') { + return 'Missing login token' + } + if (key === 'auth.oidc.completeRegistrationFailed') { + return 'Complete registration failed' + } + return key + }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken: setTokenMock, + }), + useAppStore: () => ({ + showSuccess: showSuccessMock, + showError: showErrorMock, + }), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: postMock, + }, +})) + +describe('WechatCallbackView', () => { + beforeEach(() => { + postMock.mockReset() + replaceMock.mockReset() + setTokenMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + routeState.query = {} + localStorage.clear() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + postMock.mockResolvedValueOnce({ + data: { + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, + }, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenCalledTimes(1) + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + postMock + .mockResolvedValueOnce({ + data: { + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') + expect(replaceMock).toHaveBeenCalledWith('/dashboard') + expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + postMock + .mockResolvedValueOnce({ + data: { + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ') + await wrapper.get('button').trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { + invitation_code: 'INVITE-CODE', + adopt_display_name: false, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') + expect(replaceMock).toHaveBeenCalledWith('/subscriptions') + }) +}) diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da..838f3000 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -23,20 +23,7 @@ :order-type="paymentState.orderType" @done="onPaymentDone" @success="onPaymentSuccess" - /> - - @@ -265,7 +252,7 @@